Skip to content

Commit 66b7c67

Browse files
authored
Refresh on models etag mismatch (#8491)
- Send models etag - Refresh models on 412 - This wires `ModelsManager` to `ModelFamily` so we don't mutate it mid-turn
1 parent 13c42a0 commit 66b7c67

File tree

13 files changed

+243
-86
lines changed

13 files changed

+243
-86
lines changed

codex-rs/codex-api/src/common.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub enum ResponseEvent {
5959
summary_index: i64,
6060
},
6161
RateLimits(RateLimitSnapshot),
62+
ModelsEtag(String),
6263
}
6364

6465
#[derive(Debug, Serialize, Clone)]

codex-rs/codex-api/src/endpoint/chat.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ impl Stream for AggregatedStream {
152152
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
153153
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
154154
}
155+
Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => {
156+
return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag))));
157+
}
155158
Poll::Ready(Some(Ok(ResponseEvent::Completed {
156159
response_id,
157160
token_usage,

codex-rs/codex-api/src/endpoint/models.rs

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::provider::Provider;
55
use crate::telemetry::run_with_request_telemetry;
66
use codex_client::HttpTransport;
77
use codex_client::RequestTelemetry;
8+
use codex_protocol::openai_models::ModelInfo;
89
use codex_protocol::openai_models::ModelsResponse;
910
use http::HeaderMap;
1011
use http::Method;
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
4142
&self,
4243
client_version: &str,
4344
extra_headers: HeaderMap,
44-
) -> Result<ModelsResponse, ApiError> {
45+
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
4546
let builder = || {
4647
let mut req = self.provider.build_request(Method::GET, self.path());
4748
req.headers.extend(extra_headers.clone());
@@ -66,17 +67,15 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
6667
.and_then(|value| value.to_str().ok())
6768
.map(ToString::to_string);
6869

69-
let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
70+
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
7071
.map_err(|e| {
7172
ApiError::Stream(format!(
7273
"failed to decode models response: {e}; body: {}",
7374
String::from_utf8_lossy(&resp.body)
7475
))
7576
})?;
7677

77-
let etag = header_etag.unwrap_or(etag);
78-
79-
Ok(ModelsResponse { models, etag })
78+
Ok((models, header_etag))
8079
}
8180
}
8281

@@ -102,16 +101,15 @@ mod tests {
102101
struct CapturingTransport {
103102
last_request: Arc<Mutex<Option<Request>>>,
104103
body: Arc<ModelsResponse>,
104+
etag: Option<String>,
105105
}
106106

107107
impl Default for CapturingTransport {
108108
fn default() -> Self {
109109
Self {
110110
last_request: Arc::new(Mutex::new(None)),
111-
body: Arc::new(ModelsResponse {
112-
models: Vec::new(),
113-
etag: String::new(),
114-
}),
111+
body: Arc::new(ModelsResponse { models: Vec::new() }),
112+
etag: None,
115113
}
116114
}
117115
}
@@ -122,8 +120,8 @@ mod tests {
122120
*self.last_request.lock().unwrap() = Some(req);
123121
let body = serde_json::to_vec(&*self.body).unwrap();
124122
let mut headers = HeaderMap::new();
125-
if !self.body.etag.is_empty() {
126-
headers.insert(ETAG, self.body.etag.parse().unwrap());
123+
if let Some(etag) = &self.etag {
124+
headers.insert(ETAG, etag.parse().unwrap());
127125
}
128126
Ok(Response {
129127
status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {
166164

167165
#[tokio::test]
168166
async fn appends_client_version_query() {
169-
let response = ModelsResponse {
170-
models: Vec::new(),
171-
etag: String::new(),
172-
};
167+
let response = ModelsResponse { models: Vec::new() };
173168

174169
let transport = CapturingTransport {
175170
last_request: Arc::new(Mutex::new(None)),
176171
body: Arc::new(response),
172+
etag: None,
177173
};
178174

179175
let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
182178
DummyAuth,
183179
);
184180

185-
let result = client
181+
let (models, _) = client
186182
.list_models("0.99.0", HeaderMap::new())
187183
.await
188184
.expect("request should succeed");
189185

190-
assert_eq!(result.models.len(), 0);
186+
assert_eq!(models.len(), 0);
191187

192188
let url = transport
193189
.last_request
@@ -231,12 +227,12 @@ mod tests {
231227
}))
232228
.unwrap(),
233229
],
234-
etag: String::new(),
235230
};
236231

237232
let transport = CapturingTransport {
238233
last_request: Arc::new(Mutex::new(None)),
239234
body: Arc::new(response),
235+
etag: None,
240236
};
241237

242238
let client = ModelsClient::new(
@@ -245,27 +241,25 @@ mod tests {
245241
DummyAuth,
246242
);
247243

248-
let result = client
244+
let (models, _) = client
249245
.list_models("0.99.0", HeaderMap::new())
250246
.await
251247
.expect("request should succeed");
252248

253-
assert_eq!(result.models.len(), 1);
254-
assert_eq!(result.models[0].slug, "gpt-test");
255-
assert_eq!(result.models[0].supported_in_api, true);
256-
assert_eq!(result.models[0].priority, 1);
249+
assert_eq!(models.len(), 1);
250+
assert_eq!(models[0].slug, "gpt-test");
251+
assert_eq!(models[0].supported_in_api, true);
252+
assert_eq!(models[0].priority, 1);
257253
}
258254

259255
#[tokio::test]
260256
async fn list_models_includes_etag() {
261-
let response = ModelsResponse {
262-
models: Vec::new(),
263-
etag: "\"abc\"".to_string(),
264-
};
257+
let response = ModelsResponse { models: Vec::new() };
265258

266259
let transport = CapturingTransport {
267260
last_request: Arc::new(Mutex::new(None)),
268261
body: Arc::new(response),
262+
etag: Some("\"abc\"".to_string()),
269263
};
270264

271265
let client = ModelsClient::new(
@@ -274,12 +268,12 @@ mod tests {
274268
DummyAuth,
275269
);
276270

277-
let result = client
271+
let (models, etag) = client
278272
.list_models("0.1.0", HeaderMap::new())
279273
.await
280274
.expect("request should succeed");
281275

282-
assert_eq!(result.models.len(), 0);
283-
assert_eq!(result.etag, "\"abc\"");
276+
assert_eq!(models.len(), 0);
277+
assert_eq!(etag, Some("\"abc\"".to_string()));
284278
}
285279
}

codex-rs/codex-api/src/sse/responses.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,19 @@ pub fn spawn_response_stream(
5151
telemetry: Option<Arc<dyn SseTelemetry>>,
5252
) -> ResponseStream {
5353
let rate_limits = parse_rate_limit(&stream_response.headers);
54+
let models_etag = stream_response
55+
.headers
56+
.get("X-Models-Etag")
57+
.and_then(|v| v.to_str().ok())
58+
.map(ToString::to_string);
5459
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
5560
tokio::spawn(async move {
5661
if let Some(snapshot) = rate_limits {
5762
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
5863
}
64+
if let Some(etag) = models_etag {
65+
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
66+
}
5967
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
6068
});
6169

codex-rs/codex-api/tests/models_integration.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ async fn models_client_hits_models_endpoint() {
8686
context_window: None,
8787
experimental_supported_tools: Vec::new(),
8888
}],
89-
etag: String::new(),
9089
};
9190

9291
Mock::given(method("GET"))
@@ -102,13 +101,13 @@ async fn models_client_hits_models_endpoint() {
102101
let transport = ReqwestTransport::new(reqwest::Client::new());
103102
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
104103

105-
let result = client
104+
let (models, _) = client
106105
.list_models("0.1.0", HeaderMap::new())
107106
.await
108107
.expect("models request should succeed");
109108

110-
assert_eq!(result.models.len(), 1);
111-
assert_eq!(result.models[0].slug, "gpt-test");
109+
assert_eq!(models.len(), 1);
110+
assert_eq!(models[0].slug, "gpt-test");
112111

113112
let received = server
114113
.received_requests()

codex-rs/core/src/codex.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ impl Codex {
246246

247247
let config = Arc::new(config);
248248
if config.features.enabled(Feature::RemoteModels)
249-
&& let Err(err) = models_manager.refresh_available_models(&config).await
249+
&& let Err(err) = models_manager
250+
.refresh_available_models_with_cache(&config)
251+
.await
250252
{
251253
error!("failed to refresh available models: {err:?}");
252254
}
@@ -2613,6 +2615,10 @@ async fn try_run_turn(
26132615
// token usage is available to avoid duplicate TokenCount events.
26142616
sess.update_rate_limits(&turn_context, snapshot).await;
26152617
}
2618+
ResponseEvent::ModelsEtag(etag) => {
2619+
// Update internal state with latest models etag
2620+
sess.services.models_manager.refresh_if_new_etag(etag).await;
2621+
}
26162622
ResponseEvent::Completed {
26172623
response_id: _,
26182624
token_usage,
@@ -3140,7 +3146,7 @@ mod tests {
31403146
exec_policy,
31413147
auth_manager: auth_manager.clone(),
31423148
otel_manager: otel_manager.clone(),
3143-
models_manager,
3149+
models_manager: Arc::clone(&models_manager),
31443150
tool_approvals: Mutex::new(ApprovalStore::default()),
31453151
skills_manager,
31463152
};
@@ -3227,7 +3233,7 @@ mod tests {
32273233
exec_policy,
32283234
auth_manager: Arc::clone(&auth_manager),
32293235
otel_manager: otel_manager.clone(),
3230-
models_manager,
3236+
models_manager: Arc::clone(&models_manager),
32313237
tool_approvals: Mutex::new(ApprovalStore::default()),
32323238
skills_manager,
32333239
};

0 commit comments

Comments
 (0)