Skip to content

Commit abf68cd

Browse files
bug fixes
1 parent bf97c90 commit abf68cd

File tree

8 files changed

+476
-217
lines changed

8 files changed

+476
-217
lines changed

medarc_verifiers/cli/_manifest.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import json
6+
import logging
67
from collections import Counter
78
from dataclasses import dataclass
89
from datetime import UTC, datetime
@@ -20,6 +21,29 @@
2021
PROJECT_ROOT = project_root()
2122
MANIFEST_VERSION = 2
2223

24+
logger = logging.getLogger(__name__)
25+
26+
27+
class ManifestConflictError(ValueError):
28+
"""Raised when an existing manifest conflicts with the current config."""
29+
30+
31+
def _normalize_model_slug(value: str) -> str:
32+
"""Normalize model slugs for restart comparisons.
33+
34+
Some providers expose the same model under different namespaces (e.g.
35+
`google/gemini-3-pro-preview` vs `gemini-3-pro-preview`). For now, we only
36+
normalize Gemini model slugs by stripping a single leading namespace.
37+
"""
38+
if not value:
39+
return value
40+
if "/" not in value:
41+
return value
42+
candidate = value.rsplit("/", 1)[-1]
43+
if candidate.startswith("gemini-"):
44+
return candidate
45+
return value
46+
2347

2448
class ManifestJobEntry(BaseModel):
2549
"""Pydantic model describing a single manifest job entry."""
@@ -144,7 +168,71 @@ def _require_manifest_v2(payload: Mapping[str, Any], *, path: Path | None = None
144168

145169

146170
def _sanitize_model_payload(model_payload: Mapping[str, Any]) -> dict[str, Any]:
147-
return {key: value for key, value in model_payload.items() if key not in ModelConfigSchema.resume_tolerant_fields}
171+
sanitized = {key: value for key, value in model_payload.items() if key not in ModelConfigSchema.resume_tolerant_fields}
172+
173+
model_slug = sanitized.get("model")
174+
if isinstance(model_slug, str):
175+
sanitized["model"] = _normalize_model_slug(model_slug)
176+
177+
# Provider quirks: OpenAI-compatible endpoints vary widely in what they accept when
178+
# we forward `sampling_args.extra_body`. Treat *all* of extra_body as resume-tolerant
179+
# for the purposes of manifest conflict detection so users can switch providers
180+
# without getting blocked by payload drift.
181+
sampling_args = sanitized.get("sampling_args")
182+
if isinstance(sampling_args, Mapping):
183+
updated_sampling_args = dict(sampling_args)
184+
updated_sampling_args.pop("extra_body", None)
185+
if updated_sampling_args:
186+
sanitized["sampling_args"] = updated_sampling_args
187+
else:
188+
sanitized.pop("sampling_args", None)
189+
190+
return sanitized
191+
192+
193+
def _sampling_extra_body(model_payload: Mapping[str, Any]) -> dict[str, Any] | None:
194+
sampling_args = model_payload.get("sampling_args")
195+
if not isinstance(sampling_args, Mapping):
196+
return None
197+
extra_body = sampling_args.get("extra_body")
198+
if not isinstance(extra_body, Mapping):
199+
return None
200+
normalized = _normalize_payload(extra_body)
201+
return normalized or None
202+
203+
204+
def _warn_extra_body_change(key: str, existing: Mapping[str, Any], payload: Mapping[str, Any]) -> None:
205+
existing_extra = _sampling_extra_body(existing)
206+
payload_extra = _sampling_extra_body(payload)
207+
if existing_extra is None and payload_extra is None:
208+
return
209+
if compute_checksum(existing_extra or {}) == compute_checksum(payload_extra or {}):
210+
return
211+
logger.warning(
212+
"Model '%s' sampling_args.extra_body changed; allowing restart, but providers may reject unknown fields.",
213+
key,
214+
)
215+
216+
217+
def _sampling_args_payload(model_payload: Mapping[str, Any]) -> dict[str, Any] | None:
218+
sampling_args = model_payload.get("sampling_args")
219+
if not isinstance(sampling_args, Mapping):
220+
return None
221+
normalized = _normalize_payload(sampling_args)
222+
return normalized or None
223+
224+
225+
def _warn_sampling_args_change(key: str, existing: Mapping[str, Any], payload: Mapping[str, Any]) -> None:
226+
existing_sampling = _sampling_args_payload(existing)
227+
payload_sampling = _sampling_args_payload(payload)
228+
if existing_sampling is None and payload_sampling is None:
229+
return
230+
if compute_checksum(existing_sampling or {}) == compute_checksum(payload_sampling or {}):
231+
return
232+
logger.warning(
233+
"Model '%s' sampling_args changed; allowing restart, but providers may reject unsupported parameters.",
234+
key,
235+
)
148236

149237

150238
def _effective_sampling_args(entry: ManifestJobEntry, model_payload: Mapping[str, Any]) -> Mapping[str, Any]:
@@ -247,11 +335,28 @@ def _merge_unique_model_payload(
247335
if allow_mismatch:
248336
container[key] = payload
249337
return
250-
if _sanitize_model_payload(existing) == _sanitize_model_payload(payload):
338+
sanitized_existing = _sanitize_model_payload(existing)
339+
sanitized_payload = _sanitize_model_payload(payload)
340+
if sanitized_existing == sanitized_payload:
341+
_warn_extra_body_change(key, existing, payload)
251342
container[key] = payload
252343
return
253-
msg = f"Conflicting model payload for '{key}'."
254-
raise ValueError(msg)
344+
345+
stripped_existing = dict(sanitized_existing)
346+
stripped_payload = dict(sanitized_payload)
347+
stripped_existing.pop("sampling_args", None)
348+
stripped_payload.pop("sampling_args", None)
349+
if stripped_existing == stripped_payload:
350+
_warn_sampling_args_change(key, existing, payload)
351+
_warn_extra_body_change(key, existing, payload)
352+
container[key] = payload
353+
return
354+
355+
all_keys = set(sanitized_existing) | set(sanitized_payload)
356+
diff_keys = sorted(key for key in all_keys if sanitized_existing.get(key) != sanitized_payload.get(key))
357+
suffix = f" (conflicting keys: {', '.join(diff_keys)})" if diff_keys else ""
358+
msg = f"Conflicting model payload for '{key}'{suffix}."
359+
raise ManifestConflictError(msg)
255360

256361

257362
def _merge_unique_payload(

medarc_verifiers/cli/_schemas.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ class ModelConfigSchema(BaseModel):
2929

3030
resume_tolerant_fields: ClassVar[set[str]] = frozenset(
3131
{
32+
"api_key_var",
3233
"api_base_url",
34+
"endpoints_path",
35+
"headers",
3336
"timeout",
3437
"max_connections",
3538
"max_keepalive_connections",

medarc_verifiers/utils/judge_helpers.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,16 @@ def judge_sampling_args_and_headers(
193193

194194

195195
def default_judge_api_key(base_url: str | None = None) -> str | None:
196+
# Prefer an explicit judge key regardless of provider.
197+
if os.environ.get("JUDGE_API_KEY") is not None:
198+
return os.environ.get("JUDGE_API_KEY")
199+
200+
# If judging via Prime Inference and no explicit judge key is set, fall back to PRIME_API_KEY.
196201
if base_url == PRIME_INFERENCE_URL and os.environ.get("PRIME_API_KEY") is not None:
197202
return os.environ.get("PRIME_API_KEY")
198-
elif os.environ.get("OPENAI_API_KEY") is not None:
203+
204+
# Back-compat fallback for setups that only set OPENAI_API_KEY.
205+
if os.environ.get("OPENAI_API_KEY") is not None:
199206
return os.environ.get("OPENAI_API_KEY")
200-
elif os.environ.get("JUDGE_API_KEY") is not None:
201-
return os.environ.get("JUDGE_API_KEY")
202-
else:
203-
return None
207+
208+
return None

medarc_verifiers/utils/retry.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,55 @@ def _extract_retry_delay(exc: BaseException) -> float | None:
9595
return None
9696

9797

98+
def _extract_error_type_code(exc: BaseException) -> tuple[str | None, str | None]:
99+
"""Extract provider error `type` and `code` when present.
100+
101+
This is primarily used for OpenAI-compatible errors where `exc.body` or
102+
`exc.response` contains an `{"error": {"type": ..., "code": ...}}` payload.
103+
"""
104+
body = getattr(exc, "body", None)
105+
payload: Any = None
106+
107+
if isinstance(body, dict):
108+
payload = body
109+
elif isinstance(body, str):
110+
try:
111+
import json
112+
113+
payload = json.loads(body)
114+
except Exception:
115+
payload = None
116+
117+
if payload is None:
118+
resp = getattr(exc, "response", None)
119+
if resp is not None:
120+
try:
121+
payload = resp.json()
122+
except Exception:
123+
try:
124+
import json
125+
126+
payload = json.loads(getattr(resp, "text", "") or "")
127+
except Exception:
128+
payload = None
129+
130+
if isinstance(payload, list) and payload:
131+
payload = payload[0]
132+
if not isinstance(payload, dict):
133+
# Very coarse fallback: try to detect policy_violation in message text.
134+
text = str(exc)
135+
if "policy_violation" in text:
136+
return "policy_violation", None
137+
return None, None
138+
139+
err = payload.get("error")
140+
if not isinstance(err, dict):
141+
return None, None
142+
err_type = err.get("type")
143+
err_code = err.get("code")
144+
return (err_type if isinstance(err_type, str) else None, err_code if isinstance(err_code, str) else None)
145+
146+
98147
def should_retry_exception(exc: BaseException) -> tuple[bool, int | None, str | None, float | None]:
99148
"""Identify retryable exceptions from model calls."""
100149
if isinstance(exc, AssertionError):
@@ -103,6 +152,10 @@ def should_retry_exception(exc: BaseException) -> tuple[bool, int | None, str |
103152
return True, None, message, None
104153
status = _status_code(exc)
105154
retry_delay = _extract_retry_delay(exc) if status == 429 else None
155+
if status == 403:
156+
err_type, err_code = _extract_error_type_code(exc)
157+
if err_type == "policy_violation":
158+
return True, 403, f"HTTP 403 policy_violation: {err_code}", None
106159
if isinstance(exc, (BadRequestError, httpx.HTTPStatusError)):
107160
if status == 400:
108161
return True, 400, "HTTP 400 during model call", None
@@ -162,6 +215,9 @@ async def call_with_retries(
162215
result = await func()
163216
except Exception as exc: # noqa: BLE001
164217
retry, code, reason, retry_delay = should_retry_exception(exc)
218+
# 403 policy violations are not typically transient; allow only one extra attempt.
219+
if retry and code == 403 and attempt >= 2:
220+
retry = False
165221
if retry and attempt < attempts:
166222
delay = (
167223
retry_delay

0 commit comments

Comments
 (0)