@@ -99,6 +99,44 @@ def _sanitize_schema_type(schema: dict[str, Any]) -> dict[str, Any]:
99
99
return schema
100
100
101
101
102
+ def _dereference_schema (schema : dict [str , Any ]) -> dict [str , Any ]:
103
+ """Resolves $ref pointers in a JSON schema."""
104
+
105
+ defs = schema .get ("$defs" , {})
106
+
107
+ def _resolve_refs (sub_schema : Any ) -> Any :
108
+ if isinstance (sub_schema , dict ):
109
+ if "$ref" in sub_schema :
110
+ ref_key = sub_schema ["$ref" ].split ("/" )[- 1 ]
111
+ if ref_key in defs :
112
+ # Found the reference, replace it with the definition.
113
+ resolved = defs [ref_key ].copy ()
114
+ # Merge properties from the reference, allowing overrides.
115
+ sub_schema_copy = sub_schema .copy ()
116
+ del sub_schema_copy ["$ref" ]
117
+ resolved .update (sub_schema_copy )
118
+ # Recursively resolve refs in the newly inserted part.
119
+ return _resolve_refs (resolved )
120
+ else :
121
+ # Reference not found, return as is.
122
+ return sub_schema
123
+ else :
124
+ # No $ref, so traverse deeper into the dictionary.
125
+ return {key : _resolve_refs (value ) for key , value in sub_schema .items ()}
126
+ elif isinstance (sub_schema , list ):
127
+ # Traverse into lists.
128
+ return [_resolve_refs (item ) for item in sub_schema ]
129
+ else :
130
+ # Not a dict or list, return as is.
131
+ return sub_schema
132
+
133
+ dereferenced_schema = _resolve_refs (schema )
134
+ # Remove the definitions block after resolving.
135
+ if "$defs" in dereferenced_schema :
136
+ del dereferenced_schema ["$defs" ]
137
+ return dereferenced_schema
138
+
139
+
102
140
def _sanitize_schema_formats_for_gemini (
103
141
schema : dict [str , Any ],
104
142
) -> dict [str , Any ]:
@@ -109,7 +147,10 @@ def _sanitize_schema_formats_for_gemini(
109
147
"any_of" , # 'one_of', 'all_of', 'not' to come
110
148
}
111
149
snake_case_schema = {}
112
- dict_schema_field_names : tuple [str ] = ("properties" ,) # 'defs' to come
150
+ dict_schema_field_names : tuple [str , ...] = (
151
+ "properties" ,
152
+ "defs" ,
153
+ )
113
154
for field_name , field_value in schema .items ():
114
155
field_name = _to_snake_case (field_name )
115
156
if field_name in schema_field_names :
@@ -151,8 +192,9 @@ def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema:
151
192
if not isinstance (openapi_schema , dict ):
152
193
raise TypeError ("openapi_schema must be a dictionary" )
153
194
154
- openapi_schema = _sanitize_schema_formats_for_gemini (openapi_schema )
195
+ dereferenced_schema = _dereference_schema (openapi_schema )
196
+ sanitized_schema = _sanitize_schema_formats_for_gemini (dereferenced_schema )
155
197
return Schema .from_json_schema (
156
- json_schema = _ExtendedJSONSchema .model_validate (openapi_schema ),
198
+ json_schema = _ExtendedJSONSchema .model_validate (sanitized_schema ),
157
199
api_option = get_google_llm_variant (),
158
200
)
0 commit comments