Skip to content

Commit 5c14976

Browse files
authored
Merge pull request #142 from bhardwaj-gopika/main
Serialization
2 parents d01a225 + c294f94 commit 5c14976

File tree

2 files changed

+403
-1
lines changed

2 files changed

+403
-1
lines changed

lume_torch/base.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from lume.model import LUMEModel
2525
from lume.variables import Variable
26+
import importlib
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -801,6 +802,7 @@ def __init__(
801802
Initial input values. If None, uses default values from variables.
802803
"""
803804
self.torch_model = torch_model
805+
self._initial_inputs = initial_inputs
804806

805807
# Initialize state
806808
self._state = {}
@@ -881,3 +883,213 @@ def supported_variables(self) -> dict[str, Variable]:
881883
variables[var.name] = var
882884

883885
return variables
886+
887+
def dump(
    self,
    file: Union[str, os.PathLike],
    base_key: str = "",
    save_models: bool = True,
    save_jit: bool = False,
):
    """Serialize this LUMETorchModel wrapper configuration to a YAML file.

    Two artifacts are produced: the underlying torch model is written to a
    sibling file named ``<prefix>_torch_model.yaml`` via its own ``dump``
    method, and the wrapper configuration — which references that file by
    relative name and records ``initial_inputs`` — is written to ``file``.

    Parameters
    ----------
    file : str or os.PathLike
        File path to which the YAML formatted string and corresponding
        files are saved.
    base_key : str, optional
        Base key for serialization, forwarded to the torch model's dump.
    save_models : bool, optional
        Determines whether models are saved to file.
    save_jit : bool, optional
        Determines whether the model is saved as TorchScript.
    """
    logger.info(f"Dumping LUMETorchModel wrapper configuration to: {file}")

    # Resolve the wrapper file's directory and bare name so the torch-model
    # file can be written right next to it.
    abs_target = os.path.abspath(file)
    target_dir = os.path.dirname(abs_target)
    stem = os.path.basename(os.path.splitext(abs_target)[0])

    # The underlying torch model gets its own YAML alongside the wrapper.
    model_yaml_name = f"{stem}_torch_model.yaml"
    model_yaml_path = os.path.join(target_dir, model_yaml_name)

    logger.debug(f"Dumping underlying torch model to: {model_yaml_path}")
    self.torch_model.dump(
        model_yaml_path,
        base_key=base_key,
        save_models=save_models,
        save_jit=save_jit,
    )

    # Record the fully qualified class path so loading can re-import the
    # concrete torch model class without guessing.
    model_cls = self.torch_model.__class__
    qualified_name = f"{model_cls.__module__}.{model_cls.__name__}"

    wrapper_config = {
        "model_class": "LUMETorchModel",
        "torch_model_file": model_yaml_name,  # Relative path
        "torch_model_class": qualified_name,  # Store for easier loading
        "initial_inputs": self._initial_inputs,
    }

    # Write the wrapper configuration itself.
    with open(file, "w") as f:
        yaml.dump(wrapper_config, f, default_flow_style=None, sort_keys=False)

    logger.info(f"Successfully dumped LUMETorchModel wrapper to: {file}")
952+
@classmethod
def from_file(cls, filename: str):
    """Loads a LUMETorchModel from a YAML file.

    Parameters
    ----------
    filename : str
        Path to the YAML file containing the wrapper configuration.

    Returns
    -------
    LUMETorchModel
        Instance of the wrapper loaded from the file.

    Raises
    ------
    OSError
        If the file does not exist.

    """
    # Fix: the error and log messages were f-strings with no placeholder,
    # so the offending path was never reported; interpolate `filename`.
    if not os.path.exists(filename):
        raise OSError(f"File {filename} is not found.")

    logger.info(f"Loading LUMETorchModel from file: {filename}")
    with open(filename, "r") as file:
        # Pass the path along so relative references inside the YAML can
        # be resolved against the config file's directory.
        return cls.from_yaml(file, filename)
978+
979+
@classmethod
def from_yaml(
    cls, yaml_obj: Union[str, TextIOWrapper], config_file: Optional[str] = None
):
    """Loads a LUMETorchModel from a YAML string or file object.

    Resolution order for the underlying torch model class: first the
    ``torch_model_class`` key in the wrapper config, then the
    ``model_class`` key inside the referenced torch model file; the class
    itself is looked up in ``lume_torch.models`` first, then imported via
    its fully qualified module path.

    Parameters
    ----------
    yaml_obj : str or TextIOWrapper
        YAML formatted string or file object containing the wrapper
        configuration. A string that names an existing file is read from
        disk; any other string is parsed as inline YAML.
    config_file : str, optional
        Path to the configuration file (used to resolve relative paths).
        Overwritten with the detected path when `yaml_obj` is itself a
        file or a path to one.

    Returns
    -------
    LUMETorchModel
        Instance of the wrapper loaded from the YAML configuration.

    Raises
    ------
    ValueError
        If the configuration is invalid, torch_model_file is missing, or
        the torch model class cannot be resolved.

    """
    # Load the YAML configuration; also capture the absolute config path
    # (when known) so relative file references can be resolved later.
    if isinstance(yaml_obj, TextIOWrapper):
        logger.debug(f"Reading configuration from file wrapper: {yaml_obj.name}")
        config = yaml.safe_load(yaml_obj.read())
        config_file = os.path.abspath(yaml_obj.name)
    elif isinstance(yaml_obj, str):
        # A string may be a path to a YAML file or inline YAML content;
        # disambiguate by checking the filesystem.
        if os.path.exists(yaml_obj):
            logger.debug(f"Loading configuration from file: {yaml_obj}")
            with open(yaml_obj, "r") as f:
                config = yaml.safe_load(f.read())
            config_file = os.path.abspath(yaml_obj)
        else:
            logger.debug("Parsing configuration from YAML string")
            config = yaml.safe_load(yaml_obj)
    else:
        raise ValueError("yaml_obj must be a string or file object")

    # Validate configuration
    if "torch_model_file" not in config:
        raise ValueError(
            "Configuration must include 'torch_model_file' specifying the path "
            "to the underlying torch model YAML file."
        )

    # Resolve the torch model file path; relative paths are interpreted
    # against the wrapper config's directory when that is known.
    torch_model_file = config["torch_model_file"]
    if config_file is not None and not os.path.isabs(torch_model_file):
        # Resolve relative path
        config_dir = os.path.dirname(config_file)
        torch_model_file = os.path.join(config_dir, torch_model_file)

    logger.debug(f"Loading underlying torch model from: {torch_model_file}")

    # Try to get the model class path from wrapper config first
    model_class_name = config.get("torch_model_class")

    # If not in wrapper config, fall back to the 'model_class' key inside
    # the torch model's own YAML file.
    if model_class_name is None:
        with open(torch_model_file, "r") as f:
            torch_config = yaml.safe_load(f.read())

        if "model_class" not in torch_config:
            raise ValueError(
                f"Torch model configuration in {torch_model_file} must include 'model_class'"
            )
        model_class_name = torch_config["model_class"]

    # Try to import the model class
    torch_model_class = None

    # First check if it's registered in lume_torch.models; failures here
    # are logged and we fall through to the module-path lookup below.
    try:
        from lume_torch.models import get_model

        torch_model_class = get_model(model_class_name)
        logger.debug(
            f"Loaded model class from lume_torch.models: {model_class_name}"
        )
    except (KeyError, ImportError) as e:
        logger.debug(
            f"Could not load model class {model_class_name} from lume_torch.models: {e}"
        )

    # If not found, try to import from the module path if it's a fully qualified name
    if torch_model_class is None and "." in model_class_name:
        try:
            module_path, class_name = model_class_name.rsplit(".", 1)
            module = importlib.import_module(module_path)
            torch_model_class = getattr(module, class_name)
            logger.debug(f"Loaded model class from module path: {model_class_name}")
        except (ImportError, AttributeError) as e:
            logger.debug(
                f"Could not import from module path {model_class_name}: {e}"
            )

    # If still not found, raise an error
    if torch_model_class is None:
        raise ValueError(
            f"Could not load model class: {model_class_name}. "
            "The class must be either registered in lume_torch.models or "
            "accessible via a fully qualified module path."
        )

    # Load the torch model via the resolved class's own loader.
    torch_model = torch_model_class.from_file(torch_model_file)

    # Get initial inputs from config (optional; defaults to None).
    initial_inputs = config.get("initial_inputs", None)

    logger.info("Successfully loaded LUMETorchModel wrapper")

    # Create and return the wrapper instance
    return cls(torch_model=torch_model, initial_inputs=initial_inputs)

0 commit comments

Comments
 (0)