1
1
from __future__ import annotations as _annotations
2
2
3
3
from collections .abc import AsyncIterator
4
- from contextlib import AsyncExitStack , asynccontextmanager
4
+ from contextlib import AsyncExitStack , asynccontextmanager , suppress
5
5
from dataclasses import dataclass , field
6
6
from typing import TYPE_CHECKING , Callable
7
7
8
+ from opentelemetry .trace import get_current_span
9
+
10
+ from pydantic_ai .models .instrumented import InstrumentedModel
11
+
8
12
from ..exceptions import FallbackExceptionGroup , ModelHTTPError
9
13
from . import KnownModelName , Model , ModelRequestParameters , StreamedResponse , infer_model
10
14
@@ -40,7 +44,6 @@ def __init__(
40
44
fallback_on: A callable or tuple of exceptions that should trigger a fallback.
41
45
"""
42
46
self .models = [infer_model (default_model ), * [infer_model (m ) for m in fallback_models ]]
43
- self ._model_name = f'FallBackModel[{ ", " .join (model .model_name for model in self .models )} ]'
44
47
45
48
if isinstance (fallback_on , tuple ):
46
49
self ._fallback_on = _default_fallback_condition_factory (fallback_on )
@@ -62,13 +65,19 @@ async def request(
62
65
for model in self .models :
63
66
try :
64
67
response , usage = await model .request (messages , model_settings , model_request_parameters )
65
- response .model_used = model # type: ignore
66
- return response , usage
67
68
except Exception as exc :
68
69
if self ._fallback_on (exc ):
69
70
exceptions .append (exc )
70
71
continue
71
72
raise exc
73
+ else :
74
+ with suppress (Exception ):
75
+ span = get_current_span ()
76
+ if span .is_recording ():
77
+ attributes = getattr (span , 'attributes' , {})
78
+ if attributes .get ('gen_ai.request.model' ) == self .model_name :
79
+ span .set_attributes (InstrumentedModel .model_attributes (model ))
80
+ return response , usage
72
81
73
82
raise FallbackExceptionGroup ('All models from FallbackModel failed' , exceptions )
74
83
@@ -101,12 +110,11 @@ async def request_stream(
101
110
@property
102
111
def model_name (self ) -> str :
103
112
"""The model name."""
104
- return self ._model_name
113
+ return f'fallback: { "," . join ( model . model_name for model in self .models ) } '
105
114
106
115
@property
107
- def system (self ) -> str | None :
108
- """The system / model provider, n/a for fallback models."""
109
- return None
116
+ def system (self ) -> str :
117
+ return f'fallback:{ "," .join (model .system for model in self .models )} '
110
118
111
119
@property
112
120
def base_url (self ) -> str | None :
0 commit comments