|
| 1 | +# pylint: disable=line-too-long,useless-suppression |
| 2 | +# ------------------------------------ |
| 3 | +# Copyright (c) Microsoft Corporation. |
| 4 | +# Licensed under the MIT License. |
| 5 | +# ------------------------------------ |
| 6 | +"""Shared base code for sample tests - sync dependencies only.""" |
| 7 | +import os |
| 8 | +import sys |
| 9 | +import pytest |
| 10 | +import inspect |
| 11 | +import importlib.util |
| 12 | +from typing import Optional |
| 13 | +from pydantic import BaseModel |
| 14 | + |
| 15 | + |
class BaseSampleExecutor:
    """Base helper class for executing sample files with proper environment setup.

    This class contains all shared logic that doesn't require async/aio imports.
    Subclasses implement sync/async specific credential and execution logic.
    """

    class TestReport(BaseModel):
        """Schema for validation test report."""

        # Reject any field other than the two declared below.
        model_config = {"extra": "forbid"}
        correct: bool
        reason: str

    def __init__(self, test_instance, sample_path: str, env_var_mapping: dict[str, str], **kwargs):
        """Prepare everything needed to load a sample file as a module.

        The module object is created from the file's spec but is not executed here.

        Args:
            test_instance: The test object driving this executor; stored as-is.
            sample_path: Path to the sample ``.py`` file.
            env_var_mapping: Maps a sample environment variable name to the name of
                the keyword argument (in ``kwargs``) that holds its value.
            **kwargs: Candidate values, looked up by the names in ``env_var_mapping``.
                Entries that are missing or ``None`` are skipped.

        Raises:
            ImportError: If no import spec/loader can be created for ``sample_path``.
        """
        self.test_instance = test_instance
        self.sample_path = sample_path
        self.print_calls: list[str] = []
        # Keep a handle on whatever `print` resolves to right now so that
        # _capture_print can still write to the console.
        self._original_print = print

        # Prepare environment variables
        self.env_vars = {}
        for sample_var, test_var in env_var_mapping.items():
            value = kwargs.pop(test_var, None)
            if value is not None:
                self.env_vars[sample_var] = value

        # Add the sample's directory to sys.path so it can import local modules
        self.sample_dir = os.path.dirname(sample_path)
        if self.sample_dir not in sys.path:
            sys.path.insert(0, self.sample_dir)

        # Create module spec for dynamic import
        module_name = os.path.splitext(os.path.basename(self.sample_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, self.sample_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load module {module_name} from {self.sample_path}")

        self.module = importlib.util.module_from_spec(spec)
        self.spec = spec

    def _capture_print(self, *args, **kwargs):
        """Capture print calls while still outputting to console.

        The captured text honours a string ``sep`` keyword the same way ``print``
        does, so the recorded line matches what was actually written.
        """
        sep = kwargs.get("sep")
        if not isinstance(sep, str):
            # `print` treats sep=None as a single space; any other invalid value
            # is left for the real print call below to reject.
            sep = " "
        self.print_calls.append(sep.join(str(arg) for arg in args))
        self._original_print(*args, **kwargs)

    def _get_validation_request_params(self) -> dict:
        """Get common parameters for validation request.

        Returns:
            A dict with the model name, the validation instructions, a JSON-schema
            text format built from ``TestReport``, and the captured print calls as
            the input.
        """
        return {
            "model": "gpt-4o",
            "instructions": """We just run Python code and captured a Python array of print statements.
Validating the printed content to determine if correct or not:
Respond false if any entries show:
- Error messages or exception text
- Empty or null results where data is expected
- Malformed or corrupted data
- Timeout or connection errors
- Warning messages indicating failures
- Failure to retrieve or process data
- Statements saying documents/information didn't provide relevant data
- Statements saying unable to find/retrieve information
- Asking the user to specify, clarify, or provide more details
- Suggesting to use other tools or sources
- Asking follow-up questions to complete the task
- Indicating lack of knowledge or missing information
- Responses that defer answering or redirect the question
Respond with true only if the result provides a complete, substantive answer with actual data/information.
Always respond with `reason` indicating the reason for the response.""",
            "text": {
                "format": {
                    "type": "json_schema",
                    "name": "TestReport",
                    "schema": self.TestReport.model_json_schema(),
                }
            },
            # The input field is sanitized in recordings (see conftest.py) by matching the unique prefix
            # "print contents array = ". This allows sample print statements to change without breaking playback.
            # The instructions field is preserved as-is in recordings. If you modify the instructions,
            # you must re-record the tests.
            "input": f"print contents array = {self.print_calls}",
        }

    def _assert_validation_result(self, test_report: dict) -> None:
        """Assert validation result and print reason.

        On failure the captured print statements are first written to a uniquely
        named UTF-8 log file in the temp folder, on a best-effort basis.

        Args:
            test_report: Mapping with a boolean ``correct`` and a string ``reason``.

        Raises:
            AssertionError: If ``test_report["correct"]`` is falsy.
        """
        if not test_report["correct"]:
            # Write print statements to log file in temp folder for debugging
            import tempfile
            from datetime import datetime

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            try:
                # mkstemp guarantees a unique name, so failures that happen within
                # the same second do not overwrite each other's logs.
                fd, log_file = tempfile.mkstemp(
                    prefix=f"sample_validation_error_{timestamp}_",
                    suffix=".log",
                    text=True,
                )
                # Explicit UTF-8: sample output may contain characters that the
                # platform default encoding cannot represent.
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    f.write(f"Sample: {self.sample_path}\n")
                    f.write(f"Validation Error: {test_report['reason']}\n\n")
                    f.write("Print Statements:\n")
                    f.write("=" * 80 + "\n")
                    f.writelines(
                        f"{i}. {print_call}\n" for i, print_call in enumerate(self.print_calls, 1)
                    )
                print(f"\nValidation failed! Print statements logged to: {log_file}")
            except OSError as exc:
                # The log is only a debugging aid; never let it hide the real failure.
                print(f"\nValidation failed! Could not write debug log: {exc}")
        assert test_report["correct"], f"Error is identified: {test_report['reason']}"
        print(f"Reason: {test_report['reason']}")
| 117 | + |
| 118 | + |
class SamplePathPasser:
    """Decorator that forwards the sample path argument to the wrapped test.

    Applying an instance to a function returns a thin wrapper taking
    ``(test_class, sample_path, **kwargs)`` and passing them through unchanged.
    Coroutine functions get an ``async`` wrapper; everything else gets a plain one.
    """

    def __call__(self, fn):
        # Handle the plain (non-coroutine) case first and bail out early.
        if not inspect.iscoroutinefunction(fn):

            def _wrapper_sync(test_class, sample_path, **kwargs):
                return fn(test_class, sample_path, **kwargs)

            return _wrapper_sync

        async def _wrapper_async(test_class, sample_path, **kwargs):
            result = await fn(test_class, sample_path, **kwargs)
            return result

        return _wrapper_async
| 135 | + |
| 136 | + |
def get_sample_paths(
    sub_folder: str,
    *,
    samples_to_skip: Optional[list[str]] = None,
    is_async: Optional[bool] = False,
) -> list:
    """Get list of sample paths for testing.

    Looks in ``<two levels above this file>/samples/<sub_folder>`` for files whose
    names start with ``sample_`` and wraps each match in a ``pytest.param`` whose
    id is the file name without its extension.

    Args:
        sub_folder: Folder below ``samples``, using ``/`` as the separator.
        samples_to_skip: File names (with extension) to leave out.
        is_async: When truthy, select only files ending in ``_async.py``;
            otherwise select ``.py`` files that do not end in ``_async.py``.

    Returns:
        A list of ``pytest.param`` objects, sorted by file name.

    Raises:
        ValueError: If the target folder does not exist.
    """
    # Get the path to the samples folder
    current_dir = os.path.dirname(os.path.abspath(__file__))
    samples_folder_path = os.path.normpath(os.path.join(current_dir, os.pardir, os.pardir))
    target_folder = os.path.join(samples_folder_path, "samples", *sub_folder.split("/"))

    if not os.path.exists(target_folder):
        raise ValueError(f"Target folder does not exist: {target_folder}")

    print("Target folder for samples:", target_folder)
    print("is_async:", is_async)
    print("samples_to_skip:", samples_to_skip)

    def _is_wanted(filename: str) -> bool:
        """Return True for sample files of the requested (sync or async) flavour."""
        if not filename.startswith("sample_"):
            return False
        if is_async:
            return filename.endswith("_async.py")
        return filename.endswith(".py") and not filename.endswith("_async.py")

    # Build the skip set once; sort once so both the log line and the resulting
    # parameter order are deterministic regardless of directory listing order.
    skipped = set(samples_to_skip or ())
    files_to_test = sorted(f for f in os.listdir(target_folder) if _is_wanted(f) and f not in skipped)

    print(f"Running the following samples as test:\n{files_to_test}")

    # Create pytest.param objects. splitext removes only the trailing extension,
    # whereas str.replace would also strip any ".py" inside the file name.
    return [
        pytest.param(os.path.join(target_folder, filename), id=os.path.splitext(filename)[0])
        for filename in files_to_test
    ]
| 180 | + |
| 181 | + |
def get_sample_environment_variables_map(operation_group: Optional[str] = None) -> dict[str, str]:
    """Build the lookup from sample environment variable names to test variable names.

    Args:
        operation_group: Optional operation group name (e.g., "agents"). When given,
            it is embedded in the test variable name used for the project endpoint.

    Returns:
        Dictionary whose keys are sample env var names and whose values are the
        corresponding test env var names.
    """
    # Only the project endpoint entry depends on the operation group.
    if operation_group is None:
        endpoint_test_var = "azure_ai_projects_tests_project_endpoint"
    else:
        endpoint_test_var = f"azure_ai_projects_tests_{operation_group}_project_endpoint"

    mapping: dict[str, str] = {"AZURE_AI_PROJECT_ENDPOINT": endpoint_test_var}
    mapping.update(
        {
            "AZURE_AI_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_model_deployment_name",
            "IMAGE_GENERATION_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_image_generation_model_deployment_name",
            "AI_SEARCH_PROJECT_CONNECTION_ID": "azure_ai_projects_tests_ai_search_project_connection_id",
            "AI_SEARCH_INDEX_NAME": "azure_ai_projects_tests_ai_search_index_name",
            "AI_SEARCH_USER_INPUT": "azure_ai_projects_tests_ai_search_user_input",
            "SHAREPOINT_USER_INPUT": "azure_ai_projects_tests_sharepoint_user_input",
            "SHAREPOINT_PROJECT_CONNECTION_ID": "azure_ai_projects_tests_sharepoint_project_connection_id",
            "MEMORY_STORE_CHAT_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_memory_store_chat_model_deployment_name",
            "MEMORY_STORE_EMBEDDING_MODEL_DEPLOYMENT_NAME": "azure_ai_projects_tests_memory_store_embedding_model_deployment_name",
        }
    )
    return mapping
0 commit comments