33import sys
44from collections .abc import AsyncIterator
55from datetime import timezone
6+ from typing import Any
67
78import pytest
89from inline_snapshot import snapshot
@@ -247,6 +248,14 @@ def test_all_failed() -> None:
247248 assert exceptions [0 ].body == {'error' : 'test error' }
248249
249250
251+ def add_missing_response_model (spans : list [dict [str , Any ]]) -> list [dict [str , Any ]]:
252+ for span in spans :
253+ attrs = span .setdefault ('attributes' , {})
254+ if 'gen_ai.request.model' in attrs :
255+ attrs .setdefault ('gen_ai.response.model' , attrs ['gen_ai.request.model' ])
256+ return spans
257+
258+
250259@pytest .mark .skipif (not logfire_imports_successful (), reason = 'logfire not installed' )
251260def test_all_failed_instrumented (capfire : CaptureLogfire ) -> None :
252261 fallback_model = FallbackModel (failure_model , failure_model )
@@ -260,7 +269,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
260269 assert exceptions [0 ].status_code == 500
261270 assert exceptions [0 ].model_name == 'test-function-model'
262271 assert exceptions [0 ].body == {'error' : 'test error' }
263- assert capfire .exporter .exported_spans_as_dict () == snapshot (
272+ assert add_missing_response_model ( capfire .exporter .exported_spans_as_dict () ) == snapshot (
264273 [
265274 {
266275 'name' : 'chat fallback:function:failure_response:,function:failure_response:' ,
@@ -277,6 +286,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
277286 'logfire.span_type' : 'span' ,
278287 'logfire.msg' : 'chat fallback:function:failure_response:,function:failure_response:' ,
279288 'logfire.level_num' : 17 ,
289+ 'gen_ai.response.model' : 'fallback:function:failure_response:,function:failure_response:' ,
280290 },
281291 'events' : [
282292 {
0 commit comments