|
6 | 6 | import re
|
7 | 7 | import time
|
8 | 8 | import uuid
|
9 |
| -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator |
| 9 | +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Iterator |
10 | 10 | from contextlib import asynccontextmanager, suppress
|
11 | 11 | from dataclasses import dataclass, fields, is_dataclass
|
12 | 12 | from datetime import datetime, timezone
|
@@ -70,16 +70,33 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
|
70 | 70 |
|
71 | 71 | if schema.get('type') == 'object':
|
72 | 72 | return schema
|
73 |
| - elif schema.get('$ref') is not None: |
74 |
| - maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/". |
75 |
| - |
76 |
| - if "'$ref': '#/$defs/" in str(maybe_result): |
77 |
| - return schema # We can't remove the $defs because the schema contains other references |
78 |
| - return maybe_result |
| 73 | + elif ref := schema.get('$ref'): |
| 74 | + prefix = '#/$defs/' |
| 75 | + # Return the referenced schema unless it contains additional nested references. |
| 76 | + if ( |
| 77 | + ref.startswith(prefix) |
| 78 | + and (resolved := schema.get('$defs', {}).get(ref[len(prefix) :])) |
| 79 | + and resolved.get('type') == 'object' |
| 80 | + and not _contains_ref(resolved) |
| 81 | + ): |
| 82 | + return resolved |
| 83 | + return schema |
79 | 84 | else:
|
80 | 85 | raise UserError('Schema must be an object')
|
81 | 86 |
|
82 | 87 |
|
| 88 | +def _contains_ref(obj: JsonSchemaValue | list[JsonSchemaValue]) -> bool: |
| 89 | + """Recursively check if an object contains any $ref keys.""" |
| 90 | + items: Iterable[JsonSchemaValue] |
| 91 | + if isinstance(obj, dict): |
| 92 | + if '$ref' in obj: |
| 93 | + return True |
| 94 | + items = obj.values() |
| 95 | + else: |
| 96 | + items = obj |
| 97 | + return any(isinstance(item, dict | list) and _contains_ref(item) for item in items) # pyright: ignore[reportUnknownArgumentType] |
| 98 | + |
| 99 | + |
83 | 100 | T = TypeVar('T')
|
84 | 101 |
|
85 | 102 |
|
|
0 commit comments