Skip to content

Commit dc7bdd6

Browse files
authored
Refactor the test sample functions to fix dependency on aiohttp package (Azure#44490)
1 parent b753242 commit dc7bdd6

File tree

4 files changed

+322
-260
lines changed

4 files changed

+322
-260
lines changed

sdk/ai/azure-ai-projects/samples/evaluations/sample_continuous_evaluation_rule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@
158158
if len(eval_run_list.data) > 0 and eval_run_list.data[0].report_url:
159159
run_report_url = eval_run_list.data[0].report_url
160160
# Remove the last 2 URL path segments (run/continuousevalrun_xxx)
161-
report_url = '/'.join(run_report_url.split('/')[:-2])
161+
report_url = "/".join(run_report_url.split("/")[:-2])
162162
print(f"To check evaluation runs, please open {report_url} from the browser")
163163
break
164164

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# pylint: disable=line-too-long,useless-suppression
2+
# ------------------------------------
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT License.
5+
# ------------------------------------
6+
"""Shared base code for sample tests - sync dependencies only."""
7+
import functools
import importlib.util
import inspect
import os
import sys
from typing import Optional

import pytest
from pydantic import BaseModel
14+
15+
16+
class BaseSampleExecutor:
    """Base helper class for executing sample files with proper environment setup.

    This class contains all shared logic that doesn't require async/aio imports.
    Subclasses implement sync/async specific credential and execution logic.

    The constructor stages environment variables, makes the sample's directory
    importable, and builds (but does not execute) an import spec for the sample
    module. NOTE(review): execution of ``self.spec``/``self.module`` is not
    visible in this class — presumably done by subclasses.
    """

    class TestReport(BaseModel):
        """Schema for validation test report.

        ``correct`` is the pass/fail verdict and ``reason`` is a free-text
        justification. ``extra: forbid`` rejects unknown fields so the model
        response must match this schema exactly.
        """

        model_config = {"extra": "forbid"}
        correct: bool
        reason: str

    def __init__(self, test_instance, sample_path: str, env_var_mapping: dict[str, str], **kwargs):
        """Prepare a sample file for execution.

        :param test_instance: Owning test object. Stored but not used by this
            base class; presumably consumed by subclasses.
        :param sample_path: Path to the sample ``.py`` file to run.
        :param env_var_mapping: Maps env-var names the sample reads to the
            kwarg names that carry their values.
        :param kwargs: Test values; entries named in ``env_var_mapping`` are
            popped and staged in ``self.env_vars``.
        :raises ImportError: If an import spec cannot be created for the sample.
        """
        self.test_instance = test_instance
        self.sample_path = sample_path
        # Output captured by _capture_print, later validated (see
        # _get_validation_request_params).
        self.print_calls: list[str] = []
        # Keep a handle to the builtin print so _capture_print can still echo
        # to the console — presumably `print` is patched to _capture_print
        # during sample execution (not shown here).
        self._original_print = print

        # Prepare environment variables: only values actually supplied
        # (non-None) are staged for the sample.
        self.env_vars = {}
        for sample_var, test_var in env_var_mapping.items():
            value = kwargs.pop(test_var, None)
            if value is not None:
                self.env_vars[sample_var] = value

        # Add the sample's directory to sys.path so it can import local modules
        self.sample_dir = os.path.dirname(sample_path)
        if self.sample_dir not in sys.path:
            sys.path.insert(0, self.sample_dir)

        # Create module spec for dynamic import. The module object is created
        # here but not executed.
        module_name = os.path.splitext(os.path.basename(self.sample_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, self.sample_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load module {module_name} from {self.sample_path}")

        self.module = importlib.util.module_from_spec(spec)
        self.spec = spec

    def _capture_print(self, *args, **kwargs):
        """Capture print calls while still outputting to console."""
        # Mirror print's default behavior of joining positional args with spaces.
        self.print_calls.append(" ".join(str(arg) for arg in args))
        self._original_print(*args, **kwargs)

    def _get_validation_request_params(self) -> dict:
        """Get common parameters for validation request.

        Builds the request body for a model call that judges whether the
        captured print output represents a successful sample run; the response
        is constrained to the ``TestReport`` JSON schema.

        :return: Dict with ``model``, ``instructions``, ``text`` (response
            format) and ``input`` keys.
        """
        return {
            "model": "gpt-4o",
            "instructions": """We just run Python code and captured a Python array of print statements.
            Validating the printed content to determine if correct or not:
            Respond false if any entries show:
            - Error messages or exception text
            - Empty or null results where data is expected
            - Malformed or corrupted data
            - Timeout or connection errors
            - Warning messages indicating failures
            - Failure to retrieve or process data
            - Statements saying documents/information didn't provide relevant data
            - Statements saying unable to find/retrieve information
            - Asking the user to specify, clarify, or provide more details
            - Suggesting to use other tools or sources
            - Asking follow-up questions to complete the task
            - Indicating lack of knowledge or missing information
            - Responses that defer answering or redirect the question
            Respond with true only if the result provides a complete, substantive answer with actual data/information.
            Always respond with `reason` indicating the reason for the response.""",
            "text": {
                "format": {
                    "type": "json_schema",
                    "name": "TestReport",
                    "schema": self.TestReport.model_json_schema(),
                }
            },
            # The input field is sanitized in recordings (see conftest.py) by matching the unique prefix
            # "print contents array = ". This allows sample print statements to change without breaking playback.
            # The instructions field is preserved as-is in recordings. If you modify the instructions,
            # you must re-record the tests.
            "input": f"print contents array = {self.print_calls}",
        }

    def _assert_validation_result(self, test_report: dict) -> None:
        """Assert validation result and print reason.

        :param test_report: Parsed ``TestReport`` dict with ``correct`` and
            ``reason`` keys.
        :raises AssertionError: If the report marks the run as incorrect; the
            captured print output is first dumped to a temp-folder log file
            for debugging.
        """
        if not test_report["correct"]:
            # Write print statements to log file in temp folder for debugging
            import tempfile
            from datetime import datetime

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file = os.path.join(tempfile.gettempdir(), f"sample_validation_error_{timestamp}.log")
            with open(log_file, "w") as f:
                f.write(f"Sample: {self.sample_path}\n")
                f.write(f"Validation Error: {test_report['reason']}\n\n")
                f.write("Print Statements:\n")
                f.write("=" * 80 + "\n")
                for i, print_call in enumerate(self.print_calls, 1):
                    f.write(f"{i}. {print_call}\n")
            print(f"\nValidation failed! Print statements logged to: {log_file}")
        assert test_report["correct"], f"Error is identified: {test_report['reason']}"
        print(f"Reason: {test_report['reason']}")
117+
118+
119+
class SamplePathPasser:
    """Decorator for passing sample path to test functions.

    Wraps sync and async test functions transparently. ``functools.wraps`` is
    applied so the wrapper keeps the original function's ``__name__``,
    ``__doc__``, signature metadata, and ``__dict__`` (which carries pytest
    marks) — without it, test discovery and marker propagation see only the
    anonymous wrapper.
    """

    def __call__(self, fn):
        """Return a wrapper matching *fn*'s sync/async nature.

        :param fn: A test function taking ``(test_class, sample_path, **kwargs)``.
        :return: A coroutine-function wrapper if *fn* is async, else a plain
            function wrapper, both forwarding all arguments unchanged.
        """
        if inspect.iscoroutinefunction(fn):

            @functools.wraps(fn)
            async def _wrapper_async(test_class, sample_path, **kwargs):
                return await fn(test_class, sample_path, **kwargs)

            return _wrapper_async
        else:

            @functools.wraps(fn)
            def _wrapper_sync(test_class, sample_path, **kwargs):
                return fn(test_class, sample_path, **kwargs)

            return _wrapper_sync
135+
136+
137+
def get_sample_paths(
    sub_folder: str,
    *,
    samples_to_skip: Optional[list[str]] = None,
    is_async: Optional[bool] = False,
) -> list:
    """Get list of sample paths for testing.

    Discovers ``sample_*.py`` files under ``samples/<sub_folder>`` (relative to
    the package root two levels above this file), keeping async samples
    (``*_async.py``) when *is_async* is true and sync ones otherwise, and
    returns one ``pytest.param`` per file sorted by filename.

    :param sub_folder: Sub-path under the samples folder, "/"-separated.
    :param samples_to_skip: Filenames to exclude from the run.
    :param is_async: Select async samples instead of sync ones.
    :return: Sorted list of ``pytest.param`` objects (id = filename sans ".py").
    :raises ValueError: If the resolved samples folder does not exist.
    """
    # Resolve <package root>/samples/<sub_folder> relative to this file.
    here = os.path.dirname(os.path.abspath(__file__))
    package_root = os.path.normpath(os.path.join(here, os.pardir, os.pardir))
    target_folder = os.path.join(package_root, "samples", *sub_folder.split("/"))

    if not os.path.exists(target_folder):
        raise ValueError(f"Target folder does not exist: {target_folder}")

    print("Target folder for samples:", target_folder)
    print("is_async:", is_async)
    print("samples_to_skip:", samples_to_skip)

    def _is_candidate(name: str) -> bool:
        # Sync discovery must not pick up *_async.py files.
        if not name.startswith("sample_"):
            return False
        if is_async:
            return name.endswith("_async.py")
        return name.endswith(".py") and not name.endswith("_async.py")

    # Discover all sync or async sample files in the folder.
    all_files = [name for name in os.listdir(target_folder) if _is_candidate(name)]

    files_to_test = [name for name in all_files if name not in samples_to_skip] if samples_to_skip else all_files

    print(f"Running the following samples as test:\n{files_to_test}")

    # Create pytest.param objects, sorted for stable test ordering.
    return [
        pytest.param(os.path.join(target_folder, filename), id=filename.replace(".py", ""))
        for filename in sorted(files_to_test)
    ]
180+
181+
182+
def get_sample_environment_variables_map(operation_group: Optional[str] = None) -> dict[str, str]:
    """Get the mapping of sample environment variables to test environment variables.

    Args:
        operation_group: Optional operation group name (e.g., "agents") to scope the endpoint variable.

    Returns:
        Dictionary mapping sample env var names to test env var names.
    """
    # The project endpoint is the only entry whose test variable depends on
    # the operation group.
    if operation_group is None:
        endpoint_test_var = "azure_ai_projects_tests_project_endpoint"
    else:
        endpoint_test_var = f"azure_ai_projects_tests_{operation_group}_project_endpoint"

    mapping: dict[str, str] = {"AZURE_AI_PROJECT_ENDPOINT": endpoint_test_var}
    mapping.update(
        {
            "AZURE_AI_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_model_deployment_name",
            "IMAGE_GENERATION_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_image_generation_model_deployment_name",
            "AI_SEARCH_PROJECT_CONNECTION_ID": "azure_ai_projects_tests_ai_search_project_connection_id",
            "AI_SEARCH_INDEX_NAME": "azure_ai_projects_tests_ai_search_index_name",
            "AI_SEARCH_USER_INPUT": "azure_ai_projects_tests_ai_search_user_input",
            "SHAREPOINT_USER_INPUT": "azure_ai_projects_tests_sharepoint_user_input",
            "SHAREPOINT_PROJECT_CONNECTION_ID": "azure_ai_projects_tests_sharepoint_project_connection_id",
            "MEMORY_STORE_CHAT_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_memory_store_chat_model_deployment_name",
            "MEMORY_STORE_EMBEDDING_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_memory_store_embedding_model_deployment_name",
        }
    )
    return mapping

0 commit comments

Comments
 (0)