|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import json |
| 6 | +import logging |
6 | 7 | from collections import Counter |
7 | 8 | from dataclasses import dataclass |
8 | 9 | from datetime import UTC, datetime |
|
20 | 21 | PROJECT_ROOT = project_root() |
21 | 22 | MANIFEST_VERSION = 2 |
22 | 23 |
|
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
class ManifestConflictError(ValueError):
    """Signal that an existing manifest disagrees with the current config.

    Derives from ``ValueError`` so handlers written against ``ValueError``
    also catch this error.
    """
| 29 | + |
| 30 | + |
| 31 | +def _normalize_model_slug(value: str) -> str: |
| 32 | + """Normalize model slugs for restart comparisons. |
| 33 | +
|
| 34 | + Some providers expose the same model under different namespaces (e.g. |
| 35 | + `google/gemini-3-pro-preview` vs `gemini-3-pro-preview`). For now, we only |
| 36 | + normalize Gemini model slugs by stripping a single leading namespace. |
| 37 | + """ |
| 38 | + if not value: |
| 39 | + return value |
| 40 | + if "/" not in value: |
| 41 | + return value |
| 42 | + candidate = value.rsplit("/", 1)[-1] |
| 43 | + if candidate.startswith("gemini-"): |
| 44 | + return candidate |
| 45 | + return value |
| 46 | + |
23 | 47 |
|
24 | 48 | class ManifestJobEntry(BaseModel): |
25 | 49 | """Pydantic model describing a single manifest job entry.""" |
@@ -144,7 +168,71 @@ def _require_manifest_v2(payload: Mapping[str, Any], *, path: Path | None = None |
144 | 168 |
|
145 | 169 |
|
def _sanitize_model_payload(model_payload: Mapping[str, Any]) -> dict[str, Any]:
    """Build the view of a model payload that manifest conflict detection compares.

    The result is a new dict (shallow copy) derived from ``model_payload``:

    * keys listed in ``ModelConfigSchema.resume_tolerant_fields`` are dropped;
    * a string ``model`` value is passed through ``_normalize_model_slug``;
    * ``sampling_args.extra_body`` is removed, and ``sampling_args`` itself is
      removed when nothing else remains in it.

    Args:
        model_payload: Raw model payload mapping.

    Returns:
        The sanitized payload; ``model_payload`` itself is not modified.
    """
    comparable: dict[str, Any] = {}
    for field_name, field_value in model_payload.items():
        if field_name in ModelConfigSchema.resume_tolerant_fields:
            continue
        comparable[field_name] = field_value

    slug = comparable.get("model")
    if isinstance(slug, str):
        comparable["model"] = _normalize_model_slug(slug)

    # Provider quirks: OpenAI-compatible endpoints vary widely in what they accept when
    # we forward `sampling_args.extra_body`. Treat *all* of extra_body as resume-tolerant
    # for the purposes of manifest conflict detection so users can switch providers
    # without getting blocked by payload drift.
    raw_sampling = comparable.get("sampling_args")
    if not isinstance(raw_sampling, Mapping):
        return comparable

    trimmed_sampling = dict(raw_sampling)
    trimmed_sampling.pop("extra_body", None)
    if trimmed_sampling:
        comparable["sampling_args"] = trimmed_sampling
    else:
        # Only extra_body was present (or nothing at all): compare as if absent.
        comparable.pop("sampling_args", None)
    return comparable
| 191 | + |
| 192 | + |
def _sampling_extra_body(model_payload: Mapping[str, Any]) -> dict[str, Any] | None:
    """Extract the normalized ``sampling_args.extra_body`` of a model payload.

    Args:
        model_payload: Raw model payload mapping.

    Returns:
        ``_normalize_payload(extra_body)``, or ``None`` when ``sampling_args``
        or ``extra_body`` is missing or not a mapping, or when the normalized
        result is falsy.
    """
    sampling_args = model_payload.get("sampling_args")
    extra_body = sampling_args.get("extra_body") if isinstance(sampling_args, Mapping) else None
    if isinstance(extra_body, Mapping):
        # Collapse a falsy normalized result to None so callers see one "absent" value.
        return _normalize_payload(extra_body) or None
    return None
| 202 | + |
| 203 | + |
def _warn_extra_body_change(key: str, existing: Mapping[str, Any], payload: Mapping[str, Any]) -> None:
    """Log a warning when ``sampling_args.extra_body`` differs between two payloads.

    Nothing is logged when neither payload has a usable ``extra_body`` or when
    the checksums of the two normalized bodies match.

    Args:
        key: Model identifier used in the log message.
        existing: Model payload already recorded.
        payload: Incoming model payload being compared against ``existing``.
    """
    before = _sampling_extra_body(existing)
    after = _sampling_extra_body(payload)
    if before is not None or after is not None:
        # A missing body is compared as an empty one.
        if compute_checksum(before or {}) != compute_checksum(after or {}):
            logger.warning(
                "Model '%s' sampling_args.extra_body changed; allowing restart, but providers may reject unknown fields.",
                key,
            )
| 215 | + |
| 216 | + |
def _sampling_args_payload(model_payload: Mapping[str, Any]) -> dict[str, Any] | None:
    """Extract the normalized ``sampling_args`` of a model payload.

    Args:
        model_payload: Raw model payload mapping.

    Returns:
        ``_normalize_payload(sampling_args)``, or ``None`` when ``sampling_args``
        is missing or not a mapping, or when the normalized result is falsy.
    """
    sampling_args = model_payload.get("sampling_args")
    if isinstance(sampling_args, Mapping):
        # Collapse a falsy normalized result to None so callers see one "absent" value.
        return _normalize_payload(sampling_args) or None
    return None
| 223 | + |
| 224 | + |
def _warn_sampling_args_change(key: str, existing: Mapping[str, Any], payload: Mapping[str, Any]) -> None:
    """Log a warning when ``sampling_args`` differs between two payloads.

    Nothing is logged when neither payload has usable ``sampling_args`` or when
    the checksums of the two normalized mappings match.

    Args:
        key: Model identifier used in the log message.
        existing: Model payload already recorded.
        payload: Incoming model payload being compared against ``existing``.
    """
    before = _sampling_args_payload(existing)
    after = _sampling_args_payload(payload)
    if before is not None or after is not None:
        # Missing sampling args are compared as an empty mapping.
        if compute_checksum(before or {}) != compute_checksum(after or {}):
            logger.warning(
                "Model '%s' sampling_args changed; allowing restart, but providers may reject unsupported parameters.",
                key,
            )
148 | 236 |
|
149 | 237 |
|
150 | 238 | def _effective_sampling_args(entry: ManifestJobEntry, model_payload: Mapping[str, Any]) -> Mapping[str, Any]: |
@@ -247,11 +335,28 @@ def _merge_unique_model_payload( |
247 | 335 | if allow_mismatch: |
248 | 336 | container[key] = payload |
249 | 337 | return |
250 | | - if _sanitize_model_payload(existing) == _sanitize_model_payload(payload): |
| 338 | + sanitized_existing = _sanitize_model_payload(existing) |
| 339 | + sanitized_payload = _sanitize_model_payload(payload) |
| 340 | + if sanitized_existing == sanitized_payload: |
| 341 | + _warn_extra_body_change(key, existing, payload) |
251 | 342 | container[key] = payload |
252 | 343 | return |
253 | | - msg = f"Conflicting model payload for '{key}'." |
254 | | - raise ValueError(msg) |
| 344 | + |
| 345 | + stripped_existing = dict(sanitized_existing) |
| 346 | + stripped_payload = dict(sanitized_payload) |
| 347 | + stripped_existing.pop("sampling_args", None) |
| 348 | + stripped_payload.pop("sampling_args", None) |
| 349 | + if stripped_existing == stripped_payload: |
| 350 | + _warn_sampling_args_change(key, existing, payload) |
| 351 | + _warn_extra_body_change(key, existing, payload) |
| 352 | + container[key] = payload |
| 353 | + return |
| 354 | + |
| 355 | + all_keys = set(sanitized_existing) | set(sanitized_payload) |
| 356 | + diff_keys = sorted(key for key in all_keys if sanitized_existing.get(key) != sanitized_payload.get(key)) |
| 357 | + suffix = f" (conflicting keys: {', '.join(diff_keys)})" if diff_keys else "" |
| 358 | + msg = f"Conflicting model payload for '{key}'{suffix}." |
| 359 | + raise ManifestConflictError(msg) |
255 | 360 |
|
256 | 361 |
|
257 | 362 | def _merge_unique_payload( |
|
0 commit comments