3030from pydantic import create_model
3131from torch .nn .modules .batchnorm import _BatchNorm
3232
33- from modelopt .torch .opt .config import (
34- ModeloptBaseConfig ,
35- ModeloptField ,
36- get_kwargs_for_create_model_with_rules ,
37- )
33+ from modelopt .torch .opt .config import ModeloptBaseConfig , get_kwargs_for_create_model_with_rules
3834from modelopt .torch .opt .conversion import ApplyModeError , ModelLikeModule
3935from modelopt .torch .opt .mode import (
4036 ConvertEntrypoint ,
5652 stats ,
5753 torch_detach ,
5854 torch_to ,
59- unwrap_model ,
6055)
6156
6257from .algorithms import ConstraintsFunc , get_constraints_func
6358from .conversion import NASModeRegistry
6459from .patch import PatchData , PatchManager , _modelopt_eval_recursion_guard , prep_for_eval
6560from .registry import DMRegistry
66- from .search_space import SearchSpace , generate_search_space
67- from .utils import MODELOPT_BN_CALIB_ITERS , MODELOPT_QUEUE_MAXLEN , get_subnet_config , sample , select
61+ from .search_space import generate_search_space
62+ from .utils import get_subnet_config , sample , select
6863
6964__all__ = [
7065 "AutoNASConfig" ,
7166 "AutoNASModeDescriptor" ,
7267 "AutoNASPatchManager" ,
7368 "EvolveSearcher" ,
74- "ExportConfig" ,
75- "ExportModeDescriptor" ,
7669 "IterativeSearcher" ,
7770 "RandomSearcher" ,
7871 "convert_autonas_searchspace" ,
7972 "convert_searchspace" ,
80- "export_searchspace" ,
8173 "restore_autonas_searchspace" ,
82- "restore_export" ,
8374 "restore_searchspace" ,
8475 "update_autonas_metadata" ,
8576]
8677
78+ # we have two different numbers here since during training it might take longer to stabilize
79+ MODELOPT_QUEUE_MAXLEN = 50 # indicates length of modelopt data queue for BN calib
80+ MODELOPT_BN_CALIB_ITERS = (
81+ 100 # indicates # iters in train mode 'til we trust BN stats without calib
82+ )
83+
8784
8885def _get_ratio_list ():
8986 return (0.5 , 0.67 , 1.0 )
@@ -132,25 +129,6 @@ def _norm_lin_config():
132129)
133130
134131
135- class ExportConfig (ModeloptBaseConfig ):
136- """Configuration for the export mode.
137-
138- This mode is used to export a model after NAS search.
139- """
140-
141- strict : bool = ModeloptField (
142- default = True ,
143- title = "Strict export" ,
144- description = "Enforces that the subnet configuration must exactly match during export." ,
145- )
146-
147- calib : bool = ModeloptField (
148- default = False ,
149- title = "Calibration" ,
150- description = "Whether to calibrate the subnet before exporting." ,
151- )
152-
153-
154132class AutoNASPatchManager (PatchManager ):
155133 """A class to handle the monkey patching of the model for automode."""
156134
@@ -676,48 +654,6 @@ def update_autonas_metadata(
676654 metadata ["subnet_config" ] = get_subnet_config (model )
677655
678656
679- def export_searchspace (model : nn .Module , config : ExportConfig ) -> ConvertReturnType :
680- """Export a subnet configuration of the search space to a regular model."""
681- # sanity check to avoid DP/DDP here in the entrypoint
682- model = unwrap_model (model , raise_error = True )
683-
684- # store config from model if we can find it for a future convert/restore process
685- subnet_config = get_subnet_config (model )
686-
687- # Check for patching and calibration
688- if PatchManager .is_patched (model ):
689- manager = PatchManager .get_manager (model )
690- if config .calib :
691- manager .call_post_eval ()
692- manager .unpatch ()
693-
694- # export model in-place
695- model = SearchSpace (model ).export ()
696-
697- # construct metadata
698- metadata = {
699- "subnet_config" : subnet_config ,
700- }
701-
702- return model , metadata
703-
704-
705- def restore_export (model : nn .Module , config : ExportConfig , metadata : MetadataDict ) -> nn .Module :
706- """Restore & export the subnet configuration of the search space to a regular model."""
707- # select subnet config provided in metadata
708- select (model , metadata ["subnet_config" ], strict = config ["strict" ])
709-
710- # run export
711- model , metadata_new = export_searchspace (model , config )
712-
713- # double check metadata
714- unmatched_keys = compare_dict (metadata , metadata_new )
715- if unmatched_keys :
716- raise ApplyModeError (f"Unmatched metadata={ unmatched_keys } !" )
717-
718- return model
719-
720-
721657@NASModeRegistry .register_mode
722658class AutoNASModeDescriptor (ModeDescriptor ):
723659 """Class to describe the ``"autonas"`` mode.
@@ -738,12 +674,12 @@ def config_class(self) -> type[ModeloptBaseConfig]:
738674 @property
739675 def next_modes (self ) -> set [str ] | None :
740676 """Modes that must immediately follow this mode."""
741- return {"export " , "kd_loss" , "quantize" , "sparse_magnitude" , "sparse_gpt" }
677+ return {"export_nas " , "kd_loss" , "quantize" , "sparse_magnitude" , "sparse_gpt" }
742678
743679 @property
744680 def export_mode (self ) -> str | None :
745681 """The mode that corresponds to the export mode of this mode."""
746- return "export "
682+ return "export_nas "
747683
748684 @property
749685 def search_algorithm (self ) -> type [BaseSearcher ]:
@@ -769,40 +705,3 @@ def update_for_save(self) -> UpdateEntrypoint:
769705 def update_for_new_mode (self ) -> UpdateEntrypoint :
770706 """The mode's entrypoint for updating the models state before new mode."""
771707 return update_autonas_metadata
772-
773-
774- @NASModeRegistry .register_mode
775- class ExportModeDescriptor (ModeDescriptor ):
776- """Class to describe the ``"export"`` mode.
777-
778- The properties of this mode can be inspected via the source code.
779- """
780-
781- @property
782- def name (self ) -> str :
783- """Returns the value (str representation) of the mode."""
784- return "export"
785-
786- @property
787- def config_class (self ) -> type [ModeloptBaseConfig ]:
788- """Specifies the config class for the mode."""
789- return ExportConfig
790-
791- @property
792- def is_export_mode (self ) -> bool :
793- """Whether the mode is an export mode.
794-
795- Returns:
796- True if the mode is an export mode, False otherwise. Defaults to False.
797- """
798- return True
799-
800- @property
801- def convert (self ) -> ConvertEntrypoint :
802- """The mode's entrypoint for converting a model."""
803- return export_searchspace
804-
805- @property
806- def restore (self ) -> RestoreEntrypoint :
807- """The mode's entrypoint for restoring a model."""
808- return restore_export
0 commit comments