1515from concurrent .futures import FIRST_COMPLETED , ThreadPoolExecutor , wait
1616from copy import copy
1717from functools import cache
18+ from importlib import import_module
1819from importlib .metadata import PackageNotFoundError , version
1920from json import JSONDecodeError
2021from math import ceil
21- from typing import TYPE_CHECKING , Generic , TypeVar
22+ from typing import (
23+ TYPE_CHECKING ,
24+ ForwardRef ,
25+ Generic ,
26+ TypeVar ,
27+ get_args ,
28+ )
2229from urllib .parse import quote , urljoin
2330
2431import requests
@@ -65,7 +72,7 @@ class BaseRester(Generic[T]):
6572 """Base client class with core stubs."""
6673
6774 suffix : str = ""
68- document_model : BaseModel = None # type: ignore
75+ document_model : type [ BaseModel ] | None = None
6976 supports_versions : bool = False
7077 primary_key : str = "material_id"
7178
@@ -1070,10 +1077,24 @@ def _convert_to_model(self, data: list[dict]):
10701077
10711078 def _generate_returned_model (self , doc ):
10721079 model_fields = self .document_model .model_fields
1080+
10731081 set_fields = doc .model_fields_set
10741082 unset_fields = [field for field in model_fields if field not in set_fields ]
1083+
1084+ # Update with locals() from external module if needed
1085+ other_vars = {}
1086+ if any (
1087+ isinstance (typ , ForwardRef )
1088+ for field_meta in model_fields .values ()
1089+ for typ in get_args (field_meta .annotation )
1090+ ):
1091+ other_vars = vars (import_module (self .document_model .__module__ ))
1092+
10751093 include_fields = {
1076- name : (model_fields [name ].annotation , model_fields [name ])
1094+ name : (
1095+ model_fields [name ].annotation ,
1096+ model_fields [name ],
1097+ )
10771098 for name in set_fields
10781099 }
10791100
@@ -1085,6 +1106,8 @@ def _generate_returned_model(self, doc):
10851106 fields_not_requested = (list [str ], unset_fields ),
10861107 __base__ = self .document_model ,
10871108 )
1109+ if other_vars :
1110+ data_model .model_rebuild (_types_namespace = other_vars )
10881111
10891112 def new_repr (self ) -> str :
10901113 extra = ",\n " .join (
0 commit comments