Skip to content

Commit 2ca747d

Browse files
authored
[AVC] Fix parallel running of evals (Azure#14595)
* Fix parallel running of evals.
* Code review feedback.
1 parent e7cacd0 commit 2ca747d

File tree

3 files changed

+72
-31
lines changed

3 files changed

+72
-31
lines changed

packages/python-packages/apiview-copilot/evals/_runner.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
load_recordings,
2727
save_recordings,
2828
)
29+
from src._credential import warm_up_credential
2930
from src._settings import SettingsManager
3031

3132
DEFAULT_NUM_RUNS: int = 1
@@ -41,7 +42,7 @@ def __init__(self):
4142
"resource_group_name": self.settings.get("EVALS_RG"),
4243
"project_name": self.settings.get("EVALS_PROJECT_NAME"),
4344
}
44-
self._credential_kwargs = self._create_credential_kwargs()
45+
self.credential_kwargs = self._create_credential_kwargs()
4546
self._temp_files: list[Path] = []
4647
self._temp_files_lock = threading.Lock()
4748

@@ -65,7 +66,7 @@ def _create_credential_kwargs(self) -> dict[str, Any]:
6566
def in_ci(self) -> bool:
6667
return bool(os.getenv("TF_BUILD"))
6768

68-
def _load_test_file(self, test_file: Path) -> dict:
69+
def load_test_file(self, test_file: Path) -> dict:
6970
"""Load test file - supports both JSON and YAML formats."""
7071
try:
7172
with test_file.open("r", encoding="utf-8") as f:
@@ -145,6 +146,10 @@ def __init__(self, *, num_runs: int = DEFAULT_NUM_RUNS, use_recording: bool = Fa
145146
def _ensure_context(self):
146147
if self._context is None:
147148
self._context = ExecutionContext()
149+
# Pre-acquire a token so parallel workers find it cached
150+
# instead of all racing to spawn az-cli subprocesses.
151+
if not self._context.in_ci():
152+
warm_up_credential()
148153

149154
def run(self, discovery_result: DiscoveryResult) -> list[EvaluationResult]:
150155
"""Execute all targets in the discovery result.
@@ -164,8 +169,10 @@ def run(self, discovery_result: DiscoveryResult) -> list[EvaluationResult]:
164169
def _run(self, discovery_result: DiscoveryResult) -> list[EvaluationResult]:
165170
"""Run tests in parallel with progress tracking."""
166171
workflow_count = len(discovery_result.targets)
167-
cpu_count = os.cpu_count() or 4
168-
max_workers = min(cpu_count * 2, workflow_count)
172+
# Limit concurrency to avoid overwhelming credential token
173+
# acquisition (AzureCliCredential subprocess calls fail under
174+
# heavy parallelism).
175+
max_workers = min(4, workflow_count)
169176
results = []
170177
total_targets = len(discovery_result.targets)
171178

@@ -219,7 +226,7 @@ def _execute_target(self, target: EvaluationTarget) -> EvaluationResult:
219226
test_file_paths = []
220227

221228
for test_file in target.test_files:
222-
test_case = self._context._load_test_file(test_file)
229+
test_case = self._context.load_test_file(test_file)
223230
test_file_to_case[test_file] = test_case
224231
testcase_id = test_case.get("testcase")
225232
if testcase_id:
@@ -300,7 +307,7 @@ def _run_azure_evaluation(self, testcases: list[dict], target: EvaluationTarget)
300307
evaluator_config={"metrics": evaluator.evaluator_config},
301308
target=evaluator.target_function,
302309
fail_on_evaluator_errors=False,
303-
**self._context._credential_kwargs,
310+
**self._context.credential_kwargs,
304311
)
305312
results.append(result)
306313

@@ -341,7 +348,7 @@ def show_results(self, results: list[EvaluationResult]):
341348
passed_tests = []
342349
partial_tests = []
343350
raw = result.raw_results[0]
344-
for filename, eval_result in raw.items():
351+
for _, eval_result in raw.items():
345352
for res in eval_result["rows"]:
346353
testcase = res.get("inputs.testcase", "unknown")
347354
score = res.get("outputs.metrics.score")

packages/python-packages/apiview-copilot/src/_credential.py

Lines changed: 52 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,9 @@
66

77
"""Module for retrieving Azure credentials."""
88

9+
import logging
910
import os
11+
import threading
1012

1113
from azure.identity import (
1214
AzureCliCredential,
@@ -16,29 +18,60 @@
1618
ManagedIdentityCredential,
1719
)
1820

21+
logger = logging.getLogger(__name__)
22+
23+
_credential_cache = {"instance": None}
24+
_credential_lock = threading.Lock()
25+
1926

2027
def in_ci():
2128
"""Check if the code is running in a CI environment."""
2229
return os.getenv("TF_BUILD", None) and "tests" in os.getenv("SYSTEM_DEFINITIONNAME", "")
2330

2431

2532
def get_credential():
26-
"""Get Azure credentials based on the environment."""
27-
if in_ci():
28-
# These are used by Azure Pipelines and should not be changed
29-
service_connection_id = os.environ["AZURESUBSCRIPTION_SERVICE_CONNECTION_ID"]
30-
client_id = os.environ["AZURESUBSCRIPTION_CLIENT_ID"]
31-
tenant_id = os.environ["AZURESUBSCRIPTION_TENANT_ID"]
32-
system_access_token = os.environ["SYSTEM_ACCESSTOKEN"]
33-
return AzurePipelinesCredential(
34-
service_connection_id=service_connection_id,
35-
client_id=client_id,
36-
tenant_id=tenant_id,
37-
system_access_token=system_access_token,
38-
)
39-
40-
return ChainedTokenCredential(
41-
ManagedIdentityCredential(),
42-
AzureCliCredential(),
43-
AzureDeveloperCliCredential(),
44-
)
33+
"""Get a shared Azure credential instance.
34+
35+
Returns a cached singleton so that concurrent threads reuse the same
36+
credential instead of each spawning their own token-acquisition
37+
subprocesses (which fails under high parallelism).
38+
"""
39+
if _credential_cache["instance"] is not None:
40+
return _credential_cache["instance"]
41+
42+
with _credential_lock:
43+
if _credential_cache["instance"] is not None:
44+
return _credential_cache["instance"]
45+
46+
if in_ci():
47+
service_connection_id = os.environ["AZURESUBSCRIPTION_SERVICE_CONNECTION_ID"]
48+
client_id = os.environ["AZURESUBSCRIPTION_CLIENT_ID"]
49+
tenant_id = os.environ["AZURESUBSCRIPTION_TENANT_ID"]
50+
system_access_token = os.environ["SYSTEM_ACCESSTOKEN"]
51+
_credential_cache["instance"] = AzurePipelinesCredential(
52+
service_connection_id=service_connection_id,
53+
client_id=client_id,
54+
tenant_id=tenant_id,
55+
system_access_token=system_access_token,
56+
)
57+
else:
58+
_credential_cache["instance"] = ChainedTokenCredential(
59+
ManagedIdentityCredential(),
60+
AzureCliCredential(),
61+
AzureDeveloperCliCredential(),
62+
)
63+
64+
return _credential_cache["instance"]
65+
66+
67+
def warm_up_credential():
68+
"""Pre-acquire a token so it is cached before parallel workers start.
69+
70+
This prevents a thundering-herd of concurrent subprocess calls to
71+
``az account get-access-token`` that fail under high parallelism.
72+
"""
73+
credential = get_credential()
74+
try:
75+
credential.get_token("https://cognitiveservices.azure.com/.default")
76+
except Exception as exc:
77+
logger.warning("Credential warm-up failed: %s", exc)

packages/python-packages/apiview-copilot/src/_prompt_runner.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,8 @@ def _execute_prompt_template(
153153
file_path: Path to the .prompty file.
154154
inputs: Dictionary of input variables for template rendering.
155155
configuration: Optional configuration dict. If it contains an
156-
``api_key`` entry, an ``AzureKeyCredential`` is used instead
157-
of ``DefaultAzureCredential``.
156+
``api_key`` entry, an ``AzureKeyCredential`` is used; otherwise,
157+
the shared credential from ``get_credential()`` is used.
158158
159159
Returns:
160160
The string response content from the model.
@@ -165,7 +165,7 @@ def _execute_prompt_template(
165165
from azure.ai.inference import ChatCompletionsClient
166166
from azure.ai.inference.models import SystemMessage, UserMessage
167167
from azure.core.credentials import AzureKeyCredential
168-
from azure.identity import DefaultAzureCredential
168+
from src._credential import get_credential
169169
from src._settings import SettingsManager
170170

171171
config = _parse_prompty(file_path)
@@ -193,7 +193,8 @@ def _execute_prompt_template(
193193
# Format: {FOUNDRY_ENDPOINT}/models
194194
inference_endpoint = f"{foundry_endpoint.rstrip('/')}/models"
195195

196-
# Authenticate — prefer an explicit API key (used in CI), fall back to DefaultAzureCredential
196+
# Authenticate — if an explicit API key is provided (e.g., in CI), use AzureKeyCredential;
197+
# otherwise, fall back to the shared credential from get_credential().
197198
api_key = (configuration or {}).get("api_key")
198199
if api_key:
199200
credential = AzureKeyCredential(api_key)
@@ -202,7 +203,7 @@ def _execute_prompt_template(
202203
credential=credential,
203204
)
204205
else:
205-
credential = DefaultAzureCredential()
206+
credential = get_credential()
206207
# Specify the cognitive services scope for Azure AI
207208
client = ChatCompletionsClient(
208209
endpoint=inference_endpoint,

0 commit comments

Comments
 (0)