3
3
import sys
4
4
from collections .abc import AsyncIterator
5
5
from datetime import timezone
6
+ from typing import Any
6
7
7
8
import pytest
8
9
from inline_snapshot import snapshot
@@ -247,6 +248,14 @@ def test_all_failed() -> None:
247
248
assert exceptions [0 ].body == {'error' : 'test error' }
248
249
249
250
251
+ def add_missing_response_model (spans : list [dict [str , Any ]]) -> list [dict [str , Any ]]:
252
+ for span in spans :
253
+ attrs = span .setdefault ('attributes' , {})
254
+ if 'gen_ai.request.model' in attrs :
255
+ attrs .setdefault ('gen_ai.response.model' , attrs ['gen_ai.request.model' ])
256
+ return spans
257
+
258
+
250
259
@pytest .mark .skipif (not logfire_imports_successful (), reason = 'logfire not installed' )
251
260
def test_all_failed_instrumented (capfire : CaptureLogfire ) -> None :
252
261
fallback_model = FallbackModel (failure_model , failure_model )
@@ -260,7 +269,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
260
269
assert exceptions [0 ].status_code == 500
261
270
assert exceptions [0 ].model_name == 'test-function-model'
262
271
assert exceptions [0 ].body == {'error' : 'test error' }
263
- assert capfire .exporter .exported_spans_as_dict () == snapshot (
272
+ assert add_missing_response_model ( capfire .exporter .exported_spans_as_dict () ) == snapshot (
264
273
[
265
274
{
266
275
'name' : 'chat fallback:function:failure_response:,function:failure_response:' ,
@@ -277,6 +286,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
277
286
'logfire.span_type' : 'span' ,
278
287
'logfire.msg' : 'chat fallback:function:failure_response:,function:failure_response:' ,
279
288
'logfire.level_num' : 17 ,
289
+ 'gen_ai.response.model' : 'fallback:function:failure_response:,function:failure_response:' ,
280
290
},
281
291
'events' : [
282
292
{
0 commit comments