@@ -51,16 +51,25 @@ def _set_connection(conn: BaseDBAsyncClient) -> None:
5151@asynccontextmanager
5252async def tortoise_wrapper (url : str , models : Optional [str ] = None , timeout : int = 60 ) -> AsyncIterator :
5353 """Initialize Tortoise with internal and project models, close connections when done"""
54- modules : Dict [str , Iterable [Union [str , ModuleType ]]] = {'int_models' : ['dipdup.models' ]}
54+ model_modules : Dict [str , Iterable [Union [str , ModuleType ]]] = {
55+ 'int_models' : ['dipdup.models' ],
56+ }
5557 if models :
56- modules ['models' ] = [models ]
58+ if not models .endswith ('.models' ):
59+ models += '.models'
60+ model_modules ['models' ] = [models ]
61+
62+ # NOTE: Must be called before entering Tortoise context
63+ prepare_models (models )
64+
5765 try :
5866 for attempt in range (timeout ):
5967 try :
6068 await Tortoise .init (
6169 db_url = url ,
62- modules = modules ,
70+ modules = model_modules ,
6371 )
72+
6473 # FIXME: Wait for the connection to be ready, required since 0.19.0
6574 conn = get_connection ()
6675 await conn .execute_query ('SELECT 1' )
@@ -93,17 +102,24 @@ def is_model_class(obj: Any) -> bool:
93102 return isinstance (obj , type ) and issubclass (obj , Model ) and obj != Model and not getattr (obj .Meta , 'abstract' , False )
94103
95104
96- def iter_models (package : str ) -> Iterator [Tuple [str , Type [Model ]]]:
105+ def iter_models (package : Optional [ str ] ) -> Iterator [Tuple [str , Type [Model ]]]:
97106 """Iterate over built-in and project's models"""
98- dipdup_models = importlib .import_module ('dipdup.models' )
99- package_models = importlib .import_module (f'{ package } .models' )
107+ if package and not package .endswith ('.models' ):
108+ package += '.models'
109+
110+ modules = [importlib .import_module ('dipdup.models' )]
111+ if package :
112+ modules .append (importlib .import_module (package ))
113+
114+ for models_module in modules :
115+ for attr in dir (models_module ):
116+ if attr .startswith ('_' ):
117+ continue
100118
101- for models in (dipdup_models , package_models ):
102- for attr in dir (models ):
103- model = getattr (models , attr )
104- if is_model_class (model ):
105- app = 'int_models' if models .__name__ == 'dipdup.models' else 'models'
106- yield app , model
119+ attr_value = getattr (models_module , attr )
120+ if is_model_class (attr_value ):
121+ app = 'int_models' if attr_value .__name__ == 'dipdup.models' else 'models'
122+ yield app , attr_value
107123
108124
109125def set_decimal_context (package : str ) -> None :
@@ -205,26 +221,46 @@ async def move_table(conn: BaseDBAsyncClient, name: str, schema: str, new_schema
205221 await conn .execute_script (f'ALTER TABLE { schema } .{ name } SET SCHEMA { new_schema } ' )
206222
207223
208- def prepare_models (package : str ) -> None :
209- for _ , model in iter_models (package ):
210- # NOTE: Generate missing table names before Tortoise does
211- model ._meta .db_table = model ._meta .db_table or pascal_to_snake (model .__name__ )
224+ def prepare_models (package : Optional [str ]) -> None :
225+ """Prepare TortoiseORM models to use with DipDup.
226+ Generate missing table names, validate models, increase decimal precision.
227+ """
228+ from dipdup .models import Model
212229
230+ decimal_context = decimal .getcontext ()
231+ prec = decimal_context .prec
213232
214- def validate_models (package : str ) -> None :
215- """Check project's models for common mistakes"""
216- for _ , model in iter_models (package ):
217- table_name = model ._meta .db_table
233+ for app , model in iter_models (package ):
218234
235+ # NOTE: Enforce our class for user models
236+ if app == 'models' and not issubclass (model , Model ):
237+ raise DatabaseConfigurationError ('Project models must be subclassed from `dipdup.models.Model`' , model )
238+
239+ # NOTE: Generate missing table names before Tortoise does
240+ if not model ._meta .db_table :
241+ model ._meta .db_table = pascal_to_snake (model .__name__ )
242+
243+ # NOTE: Enforce tables in snake_case
244+ table_name = model ._meta .db_table
219245 if table_name != pascal_to_snake (table_name ):
220246 raise DatabaseConfigurationError ('Table name must be in snake_case' , model )
221247
222248 for field in model ._meta .fields_map .values ():
249+ # NOTE: Enforce fields in snake_case
223250 field_name = field .model_field_name
224251
225252 if field_name != pascal_to_snake (field_name ):
226253 raise DatabaseConfigurationError ('Model fields must be in snake_case' , model )
227254
228- # NOTE: Leads to GraphQL issues
229- if field_name == table_name :
230- raise DatabaseConfigurationError ('Model fields must differ from table name' , model )
255+ # NOTE: Increase decimal precision if needed
256+ if isinstance (field , DecimalField ):
257+ prec = max (prec , field .max_digits )
258+
259+ # NOTE: Set new decimal precision
260+ if decimal_context .prec < prec :
261+ _logger .warning ('Decimal context precision has been updated: %s -> %s' , decimal_context .prec , prec )
262+ decimal_context .prec = prec
263+
264+ # NOTE: DefaultContext is used for new threads
265+ decimal .DefaultContext .prec = prec
266+ decimal .setcontext (decimal_context )
0 commit comments