Skip to content

Commit bb5dd77

Browse files
authored
Move all configuration to ApiViewReviewer constructor. (Azure#10395)
1 parent dbfa196 commit bb5dd77

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

packages/python-packages/apiview-copilot/cli.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def local_review(
5555
"""
5656
from src._apiview_reviewer import ApiViewReview
5757

58-
rg = ApiViewReview(language=language, model=model)
58+
rg = ApiViewReview(
59+
language=language, model=model, chunk_input=chunk_input, use_rag=use_rag
60+
)
5961
filename = os.path.splitext(os.path.basename(path))[0]
6062

6163
with open(path, "r") as f:
6264
apiview = f.read()
63-
review = rg.get_response(apiview, chunk_input=chunk_input, use_rag=use_rag)
65+
review = rg.get_response(apiview)
6466
output_path = os.path.join("scratch", "output", language)
6567
os.makedirs(output_path, exist_ok=True)
6668
output_file = os.path.join(output_path, f"{filename}.json")

packages/python-packages/apiview-copilot/evals/run.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
dotenv.load_dotenv()
1717

18-
MODEL = "o3-mini"
1918
NUM_RUNS: int = 3
2019

2120

@@ -101,8 +100,8 @@ def review_apiview(query: str, language: str):
101100
ApiViewReview,
102101
)
103102

104-
ai_review = ApiViewReview(language=language, model=MODEL)
105-
review = ai_review.get_response(query, chunk_input=False, use_rag=False)
103+
ai_review = ApiViewReview(language=language)
104+
review = ai_review.get_response(query)
106105
return {"response": review.model_dump_json()}
107106

108107

packages/python-packages/apiview-copilot/src/_apiview_reviewer.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,30 @@
6060
supported_models = [x for x in model_map.keys()]
6161

6262
DEFAULT_MODEL = "o3-mini"
63+
DEFAULT_USE_RAG = False
64+
DEFAULT_CHUNK_INPUT = False
6365

6466

6567
class ApiViewReview:
6668

67-
def __init__(self, *, language: str, model: str = DEFAULT_MODEL):
68-
self.language = language
69-
self.model = model
70-
self.search = SearchManager(language=language)
71-
self.output_parser = ReviewResult
72-
self.semantic_search_failed = False
69+
def __init__(
70+
self,
71+
*,
72+
language: str,
73+
model: str = DEFAULT_MODEL,
74+
use_rag: bool = DEFAULT_USE_RAG,
75+
chunk_input: bool = DEFAULT_CHUNK_INPUT,
76+
):
7377
if model not in supported_models:
7478
raise ValueError(
7579
f"Model {model} not supported. Supported models are: {', '.join(supported_models)}"
7680
)
81+
self.language = language
82+
self.model = model
83+
self.use_rag = use_rag
84+
self.chunk_input = chunk_input
85+
self.search = SearchManager(language=language)
86+
self.semantic_search_failed = False
7787

7888
def _hash(self, obj) -> str:
7989
return str(hash(json.dumps(obj)))
@@ -89,13 +99,11 @@ def _ensure_env_vars(self, vars: List[str]):
8999
if missing:
90100
raise ValueError(f"Environment variables not set: {', '.join(missing)}")
91101

92-
def get_response(
93-
self, apiview: str, *, chunk_input: bool = False, use_rag: bool = False
94-
) -> ReviewResult:
102+
def get_response(self, apiview: str) -> ReviewResult:
95103
print(f"Generating review...")
96104

97105
logger.info(
98-
f"Starting review with model: {self.model}, language: {self.language}, RAG: {use_rag}"
106+
f"Starting review with model: {self.model}, language: {self.language}, RAG: {self.use_rag}"
99107
)
100108

101109
start_time = time()
@@ -106,7 +114,9 @@ def get_response(
106114
static_guideline_ids = [x["id"] for x in static_guidelines]
107115

108116
# Prepare the document
109-
chunked_apiview = SectionedDocument(apiview.splitlines(), chunk=chunk_input)
117+
chunked_apiview = SectionedDocument(
118+
apiview.splitlines(), chunk=self.chunk_input
119+
)
110120
final_results = ReviewResult(
111121
guideline_ids=static_guideline_ids, status="Success", violations=[]
112122
)
@@ -179,7 +189,7 @@ def process_chunk(chunk_info):
179189

180190
try:
181191
# build the context string
182-
if use_rag:
192+
if self.use_rag:
183193
context = self._retrieve_and_resolve_guidelines(str(chunk))
184194
if context:
185195
context_string = context.to_markdown()

0 commit comments

Comments
 (0)