Skip to content

Commit a239716

Browse files
xuanyang15copybara-github
authored andcommitted
ADK changes
PiperOrigin-RevId: 813321782
1 parent c51ea0b commit a239716

File tree

2 files changed

+103
-3
lines changed

2 files changed

+103
-3
lines changed

src/google/adk/tools/_gemini_schema_util.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,44 @@ def _sanitize_schema_type(schema: dict[str, Any]) -> dict[str, Any]:
9999
return schema
100100

101101

102+
def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]:
103+
"""Resolves $ref pointers in a JSON schema."""
104+
105+
defs = schema.get("$defs", {})
106+
107+
def _resolve_refs(sub_schema: Any) -> Any:
108+
if isinstance(sub_schema, dict):
109+
if "$ref" in sub_schema:
110+
ref_key = sub_schema["$ref"].split("/")[-1]
111+
if ref_key in defs:
112+
# Found the reference, replace it with the definition.
113+
resolved = defs[ref_key].copy()
114+
# Merge properties from the reference, allowing overrides.
115+
sub_schema_copy = sub_schema.copy()
116+
del sub_schema_copy["$ref"]
117+
resolved.update(sub_schema_copy)
118+
# Recursively resolve refs in the newly inserted part.
119+
return _resolve_refs(resolved)
120+
else:
121+
# Reference not found, return as is.
122+
return sub_schema
123+
else:
124+
# No $ref, so traverse deeper into the dictionary.
125+
return {key: _resolve_refs(value) for key, value in sub_schema.items()}
126+
elif isinstance(sub_schema, list):
127+
# Traverse into lists.
128+
return [_resolve_refs(item) for item in sub_schema]
129+
else:
130+
# Not a dict or list, return as is.
131+
return sub_schema
132+
133+
dereferenced_schema = _resolve_refs(schema)
134+
# Remove the definitions block after resolving.
135+
if "$defs" in dereferenced_schema:
136+
del dereferenced_schema["$defs"]
137+
return dereferenced_schema
138+
139+
102140
def _sanitize_schema_formats_for_gemini(
103141
schema: dict[str, Any],
104142
) -> dict[str, Any]:
@@ -109,7 +147,10 @@ def _sanitize_schema_formats_for_gemini(
109147
"any_of", # 'one_of', 'all_of', 'not' to come
110148
}
111149
snake_case_schema = {}
112-
dict_schema_field_names: tuple[str] = ("properties",) # 'defs' to come
150+
dict_schema_field_names: tuple[str, ...] = (
151+
"properties",
152+
"defs",
153+
)
113154
for field_name, field_value in schema.items():
114155
field_name = _to_snake_case(field_name)
115156
if field_name in schema_field_names:
@@ -151,8 +192,9 @@ def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema:
151192
if not isinstance(openapi_schema, dict):
152193
raise TypeError("openapi_schema must be a dictionary")
153194

154-
openapi_schema = _sanitize_schema_formats_for_gemini(openapi_schema)
195+
dereferenced_schema = _dereference_schema(openapi_schema)
196+
sanitized_schema = _sanitize_schema_formats_for_gemini(dereferenced_schema)
155197
return Schema.from_json_schema(
156-
json_schema=_ExtendedJSONSchema.model_validate(openapi_schema),
198+
json_schema=_ExtendedJSONSchema.model_validate(sanitized_schema),
157199
api_option=get_google_llm_variant(),
158200
)

tests/unittests/tools/test_gemini_schema_util.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,64 @@ def test_to_gemini_schema_remove_unrecognized_fields(self):
224224
assert gemini_schema.type == Type.STRING
225225
assert not gemini_schema.format
226226

227+
def test_to_gemini_schema_nested_dict_with_defs_and_ref(self):
228+
"""Test that nested dict with $defs and $refs is converted correctly."""
229+
openapi_schema = {
230+
"$defs": {
231+
"DeviceEnum": {
232+
"enum": ["GLOBAL", "desktop", "mobile"],
233+
"title": "DeviceEnum",
234+
"type": "string",
235+
},
236+
"DomainPayload": {
237+
"properties": {
238+
"adDomain": {
239+
"description": "List of one or many domains.",
240+
"items": {"type": "string"},
241+
"title": "Addomain",
242+
"type": "array",
243+
},
244+
"device": {
245+
"$ref": "#/$defs/DeviceEnum",
246+
"default": "GLOBAL",
247+
"description": (
248+
"Filter by device. All devices are returned by"
249+
" default."
250+
),
251+
},
252+
},
253+
"required": ["adDomain"],
254+
"title": "DomainPayload",
255+
"type": "object",
256+
},
257+
},
258+
"properties": {"payload": {"$ref": "#/$defs/DomainPayload"}},
259+
"required": ["payload"],
260+
"title": "query_domainsArguments",
261+
"type": "object",
262+
}
263+
gemini_schema = _to_gemini_schema(openapi_schema)
264+
assert gemini_schema.type == Type.OBJECT
265+
assert gemini_schema.properties["payload"].type == Type.OBJECT
266+
assert (
267+
gemini_schema.properties["payload"].properties["adDomain"].type
268+
== Type.ARRAY
269+
)
270+
assert (
271+
gemini_schema.properties["payload"].properties["adDomain"].items.type
272+
== Type.STRING
273+
)
274+
assert (
275+
gemini_schema.properties["payload"].properties["device"].type
276+
== Type.STRING
277+
)
278+
assert gemini_schema.properties["payload"].properties["device"].enum == [
279+
"GLOBAL",
280+
"desktop",
281+
"mobile",
282+
]
283+
assert gemini_schema.properties["payload"].required == ["adDomain"]
284+
227285
def test_sanitize_integer_formats(self):
228286
"""Test that int32 and int64 formats are preserved for integer types"""
229287
openapi_schema = {

0 commit comments

Comments
 (0)