Skip to content

Commit 6f7b730

Browse files
committed
feat: add AI module for LLM interaction and a heuristic for checking code–docstring consistency
Signed-off-by: Amine <[email protected]>
1 parent 6aa7a4c commit 6f7b730

File tree

10 files changed

+572
-3
lines changed

10 files changed

+572
-3
lines changed

src/macaron/ai.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
3+
4+
"""This module provides a client for interacting with a Large Language Model (LLM)."""
5+
6+
import json
7+
import logging
8+
import re
9+
from typing import Any, TypeVar
10+
11+
from pydantic import BaseModel, ValidationError
12+
13+
from macaron.config.defaults import defaults
14+
from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
15+
from macaron.util import send_post_http_raw
16+
17+
logger: logging.Logger = logging.getLogger(__name__)
18+
19+
T = TypeVar("T", bound=BaseModel)
20+
21+
22+
class AIClient:
    """A client for interacting with a Large Language Model."""

    def __init__(self, system_prompt: str):
        """
        Initialize the AI client.

        The LLM configuration (enabled, API key, endpoint, model, context window)
        is read from the [llm] section of defaults.

        Parameters
        ----------
        system_prompt: str
            The system prompt to send with every request. A blank prompt falls
            back to a generic assistant prompt.

        Raises
        ------
        ConfigurationError
            If the client is enabled but the API key, endpoint, or model is missing.
        """
        self.enabled, self.api_endpoint, self.api_key, self.model, self.context_window = self._load_defaults()
        self.system_prompt = system_prompt.strip() or "You are a helpful AI assistant."
        logger.info("AI client is %s.", "enabled" if self.enabled else "disabled")

    def _load_defaults(self) -> tuple[bool, str, str, str, int]:
        """Load the LLM configuration from the defaults.

        Returns
        -------
        tuple[bool, str, str, str, int]
            ``(enabled, api_endpoint, api_key, model, context_window)``. The order
            matches the attribute assignment in ``__init__``.

        Raises
        ------
        ConfigurationError
            If the client is enabled but a required setting is empty.
        """
        section_name = "llm"
        enabled = False
        api_endpoint = ""
        api_key = ""
        model = ""
        context_window = 10000

        if defaults.has_section(section_name):
            section = defaults[section_name]
            enabled = section.get("enabled", "False").strip().lower() == "true"
            api_key = section.get("api_key", "").strip()
            api_endpoint = section.get("api_endpoint", "").strip()
            model = section.get("model", "").strip()
            context_window = section.getint("context_window", 10000)

        # Only validate the remaining settings when the client is actually enabled,
        # so a disabled client never raises on missing configuration.
        if enabled:
            if not api_key:
                raise ConfigurationError("API key for the AI client is not configured.")
            if not api_endpoint:
                raise ConfigurationError("API endpoint for the AI client is not configured.")
            if not model:
                raise ConfigurationError("Model for the AI client is not configured.")

        return enabled, api_endpoint, api_key, model, context_window

    def _validate_response(self, response_text: str, response_model: type[T]) -> T:
        """
        Validate and parse the response from the LLM.

        If raw JSON parsing fails, attempts to extract a JSON object from the text.

        Parameters
        ----------
        response_text: str
            The response text from the LLM.
        response_model: type[T]
            The Pydantic model to validate the response against.

        Returns
        -------
        T
            The validated Pydantic model instance.

        Raises
        ------
        HeuristicAnalyzerValueError
            If there is an error in parsing or validating the response.
        """
        try:
            data = json.loads(response_text)
        except json.JSONDecodeError:
            logger.debug("Full JSON parse failed; trying to extract JSON from text.")
            # The model may wrap the JSON in prose or markdown fences; grab the
            # outermost brace-delimited span as a best-effort fallback.
            match = re.search(r"\{.*\}", response_text, re.DOTALL)
            if not match:
                # No cause exception exists here, so suppress chaining explicitly.
                raise HeuristicAnalyzerValueError("No JSON object found in the LLM response.") from None
            try:
                data = json.loads(match.group(0))
            except json.JSONDecodeError as e:
                logger.error("Failed to parse extracted JSON: %s", e)
                raise HeuristicAnalyzerValueError("Invalid JSON extracted from response.") from e

        try:
            return response_model.model_validate(data)
        except ValidationError as e:
            logger.error("Validation failed against response model: %s", e)
            raise HeuristicAnalyzerValueError("Response JSON validation failed.") from e

    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        max_tokens: int = 4000,
        structured_output: type[T] | None = None,
        timeout: int = 30,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        max_tokens: int
            The maximum number of tokens for the LLM response.
        structured_output: type[T] | None
            The Pydantic model to validate the response against. If provided, the
            response will be parsed and validated.
        timeout: int
            The timeout for the HTTP request in seconds.

        Returns
        -------
        T | str
            The validated Pydantic model instance if `structured_output` is provided,
            or the raw string response if not.

        Raises
        ------
        ConfigurationError
            If the client is not enabled.
        HeuristicAnalyzerValueError
            If there is an error in getting, parsing, or validating the response.
        """
        if not self.enabled:
            raise ConfigurationError("AI client is not enabled. Please check your configuration.")

        # NOTE(review): the context window is enforced on whitespace-separated
        # words, which only approximates the provider's token count.
        words = user_prompt.split()
        if len(words) > self.context_window:
            logger.warning(
                "User prompt exceeds context window (%s words). "
                "Truncating the prompt to fit within the context window.",
                self.context_window,
            )
            user_prompt = " ".join(words[: self.context_window])

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": self.model,
            "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            response = send_post_http_raw(url=self.api_endpoint, json_data=payload, headers=headers, timeout=timeout)
            if not response:
                raise HeuristicAnalyzerValueError("No response received from the LLM.")
            response_json = response.json()

            # Log token usage when the provider reports it; useful for cost tracking.
            usage = response_json.get("usage", {})
            if usage:
                usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
                logger.info("LLM call token usage: %s", usage_str)

            message_content = response_json["choices"][0]["message"]["content"]

            if not structured_output:
                logger.debug("Returning raw message content (no structured output requested).")
                return message_content
            return self._validate_response(message_content, structured_output)
        except HeuristicAnalyzerValueError:
            # Already a precise domain error (e.g. empty response, bad JSON,
            # validation failure); re-raise it unchanged instead of re-wrapping
            # and duplicating the message.
            raise
        except Exception as e:
            logger.error("Error during LLM invocation: %s", e)
            raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e

src/macaron/config/defaults.ini

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,3 +632,17 @@ custom_semgrep_rules_path =
632632
# .yaml prefix. Note, this will be ignored if a path to custom semgrep rules is not provided. This list may not contain
633633
# duplicated elements, meaning that ruleset names must be unique.
634634
disabled_custom_rulesets =
635+
636+
[llm]
637+
# The LLM configuration for Macaron.
638+
# If enabled, the LLM will be used to analyze the results and provide insights.
639+
enabled =
640+
# The API key for the LLM service.
641+
api_key =
642+
# The API endpoint for the LLM service.
643+
api_endpoint =
644+
# The model to use for the LLM service.
645+
model =
646+
# The context window size for the LLM service.
647+
# This is the maximum prompt size (approximated as whitespace-separated words) sent to the LLM in a single request.
648+
context_window = 10000

src/macaron/malware_analyzer/pypi_heuristics/heuristics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class Heuristics(str, Enum):
4949
#: Indicates that the package has a similar structure to other packages maintained by the same user.
5050
SIMILAR_PROJECTS = "similar_projects"
5151

52+
#: Indicates that the package contains some code that doesn't match the docstrings.
53+
MATCHING_DOCSTRINGS = "matching_docstrings"
54+
5255

5356
class HeuristicResult(str, Enum):
5457
"""Result type indicating the outcome of a heuristic."""
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
3+
4+
"""This analyzer checks the iconsistency of code with its docstrings."""
5+
6+
import logging
7+
import time
8+
from typing import Literal
9+
10+
from pydantic import BaseModel, Field
11+
12+
from macaron.ai import AIClient
13+
from macaron.json_tools import JsonType
14+
from macaron.malware_analyzer.pypi_heuristics.base_analyzer import BaseHeuristicAnalyzer
15+
from macaron.malware_analyzer.pypi_heuristics.heuristics import HeuristicResult, Heuristics
16+
from macaron.slsa_analyzer.package_registry.pypi_registry import PyPIPackageJsonAsset
17+
18+
logger: logging.Logger = logging.getLogger(__name__)
19+
20+
21+
class Result(BaseModel):
    """The result after analysing the code with its docstrings."""

    # Final verdict from the LLM; constrained to the two expected labels so any
    # other value fails pydantic validation.
    decision: Literal["consistent", "inconsistent"] = Field(
        description=""" The final decision after analysing the code with its docstrings.
        It can be either 'consistent' or 'inconsistent'."""
    )
    # Short human-readable justification for the verdict.
    reason: str = Field(
        description=" The reason for the decision made. It should be a short sentence explaining the decision."
    )
    # The offending snippet; None (or empty) when the decision is 'consistent'.
    inconsistent_code_part: str | None = Field(
        default=None,
        description=""" The specific part of the code that is inconsistent with the docstring.
        Empty if the decision is 'consistent'.""",
    )
36+
37+
38+
class MatchingDocstringsAnalyzer(BaseHeuristicAnalyzer):
    """Check whether the docstrings and the code components are consistent."""

    # Instruction sent verbatim to the LLM. It asks for a JSON object matching the
    # ``Result`` schema; "/no_think" is a model-specific directive left in the
    # prompt deliberately — do not edit this text without re-testing the model.
    SYSTEM_PROMPT = """
    You are a code master who can detect the inconsistency of the code with the docstrings that describes its components.
    You will be given a python code file. Your task is to determine whether the code is consistent with the docstrings.
    Wrap the output in `json` tags.
    Your response must be a JSON object matching this schema:
    {
        "decision": "'consistent' or 'inconsistent'",
        "reason": "A short explanation.", "inconsistent_code_part":
        "The inconsistent code, or null."
    }

    /no_think
    """

    # Delay (seconds) between successive LLM calls to avoid provider rate limits.
    REQUEST_INTERVAL = 0.5

    def __init__(self) -> None:
        super().__init__(
            name="matching_docstrings_analyzer",
            heuristic=Heuristics.MATCHING_DOCSTRINGS,
            depends_on=None,
        )
        # One client per analyzer instance; raises ConfigurationError at
        # construction time if the [llm] section is enabled but incomplete.
        self.client = AIClient(system_prompt=self.SYSTEM_PROMPT.strip())

    def analyze(self, pypi_package_json: PyPIPackageJsonAsset) -> tuple[HeuristicResult, dict[str, JsonType]]:
        """Analyze the package.

        Iterates over the package's Python sources and asks the LLM whether each
        file's code matches its docstrings. Fails on the FIRST inconsistent file;
        remaining files are not examined.

        Parameters
        ----------
        pypi_package_json: PyPIPackageJsonAsset
            The PyPI package JSON asset object.

        Returns
        -------
        tuple[HeuristicResult, dict[str, JsonType]]:
            The result and related information collected during the analysis.
            SKIP when the client is disabled or no source code is available;
            FAIL with file/reason details on an inconsistency; PASS otherwise.
        """
        if not self.client.enabled:
            logger.warning("AI client is not enabled, skipping the matching docstrings analysis.")
            return HeuristicResult.SKIP, {}

        download_result = pypi_package_json.download_sourcecode()
        if not download_result:
            logger.warning("No source code found for the package, skipping the matching docstrings analysis.")
            return HeuristicResult.SKIP, {}

        for file, content in pypi_package_json.iter_sourcecode():
            # Only Python files are analyzed; other file types are skipped.
            if file.endswith(".py"):
                time.sleep(self.REQUEST_INTERVAL)  # Respect the request interval to avoid rate limiting.
                # Decode permissively: analysis is best-effort, so undecodable
                # bytes are dropped rather than aborting the heuristic.
                code_str = content.decode("utf-8", "ignore")
                analysis_result = self.client.invoke(
                    user_prompt=code_str,
                    structured_output=Result,
                )
                if analysis_result and analysis_result.decision == "inconsistent":
                    return HeuristicResult.FAIL, {
                        "file": file,
                        "reason": analysis_result.reason,
                        "inconsistent part": analysis_result.inconsistent_code_part or "",
                    }
        return HeuristicResult.PASS, {}

0 commit comments

Comments
 (0)