3535import pickle
3636import re
3737import shutil
38- import site
3938import subprocess
4039import sys
4140import tempfile
6564_unset = object ()
6665
6766
68- SITE_PACKAGES : typing .Final [pathlib .Path ] = pathlib .Path (
69- site .getsitepackages ()[0 ]
70- )
71- MODELS_DEST : typing .Final [pathlib .Path ] = SITE_PACKAGES / "models"
72-
73-
7467@contextlib .contextmanager
7568def silence_asyncio_long_exec_warning ():
7669 def flt (log_record ):
@@ -635,7 +628,7 @@ def get_database_name(cls):
635628 return dbname .lower ()
636629
637630 @classmethod
638- def get_schema_texts (cls ) -> list [ str ] :
631+ def get_combined_schemas (cls ) -> str :
639632 schema_texts : list [str ] = []
640633
641634 # Look at all SCHEMA entries and potentially create multiple
@@ -647,7 +640,11 @@ def get_schema_texts(cls) -> list[str]:
647640 if schema_text := cls .get_schema_text (name ):
648641 schema_texts .append (schema_text )
649642
650- return schema_texts
643+ return "\n \n " .join (st for st in schema_texts )
644+
645+ @classmethod
646+ def is_schema_field (cls , field : str ) -> bool :
647+ return bool (re .match (r"^SCHEMA(?:_(\w+))?" , field ))
651648
652649 @classmethod
653650 def get_schema_text (cls , field : str ) -> str | None :
@@ -678,12 +675,11 @@ def get_schema_text(cls, field: str) -> str | None:
678675 @classmethod
679676 def get_setup_script (cls ):
680677 script = ""
681- schema = "\n \n " .join (st for st in cls .get_schema_texts ())
682678
683679 # Don't wrap the script into a transaction here, so that
684680 # potentially it's easier to stitch multiple such scripts
685681 # together in a fashion similar to what `edb inittestdb` does.
686- script += f"\n START MIGRATION TO {{ { schema } }};"
682+ script += f"\n START MIGRATION TO {{ { cls . get_combined_schemas () } }};"
687683 script += f"\n POPULATE MIGRATION; \n COMMIT MIGRATION;"
688684
689685 if cls .SETUP :
@@ -763,6 +759,7 @@ def adapt_call(cls, result):
763759
764760
765761class BaseModelTestCase (DatabaseTestCase ):
762+ SCHEMA : str
766763 DEFAULT_MODULE = "default"
767764
768765 client : typing .ClassVar [gel .Client ]
@@ -790,7 +787,7 @@ def setUpClass(cls):
790787
791788 cls .tmp_model_dir = tempfile .TemporaryDirectory (** td_kwargs )
792789
793- model_from_file , model_name = cls ._model_info ()
790+ _ , model_name = cls ._model_info ()
794791
795792 if cls .orm_debug :
796793 print (cls .tmp_model_dir .name )
@@ -830,15 +827,6 @@ def setUpClass(cls):
830827 finally :
831828 gen_client .terminate ()
832829
833- if not model_from_file :
834- # This is a direct schema, let's copy it to the site paths
835- site_model_dir = MODELS_DEST / model_name
836- if site_model_dir .exists ():
837- shutil .rmtree (site_model_dir )
838- shutil .copytree (
839- model_output_dir , site_model_dir , dirs_exist_ok = True
840- )
841-
842830 sys .path .insert (0 , cls .tmp_model_dir .name )
843831
844832 import models
0 commit comments