|
23 | 23 |
|
24 | 24 | from lume.model import LUMEModel |
25 | 25 | from lume.variables import Variable |
| 26 | +import importlib |
26 | 27 |
|
27 | 28 | logger = logging.getLogger(__name__) |
28 | 29 |
|
@@ -801,6 +802,7 @@ def __init__( |
801 | 802 | Initial input values. If None, uses default values from variables. |
802 | 803 | """ |
803 | 804 | self.torch_model = torch_model |
| 805 | + self._initial_inputs = initial_inputs |
804 | 806 |
|
805 | 807 | # Initialize state |
806 | 808 | self._state = {} |
@@ -881,3 +883,213 @@ def supported_variables(self) -> dict[str, Variable]: |
881 | 883 | variables[var.name] = var |
882 | 884 |
|
883 | 885 | return variables |
| 886 | + |
| 887 | + def dump( |
| 888 | + self, |
| 889 | + file: Union[str, os.PathLike], |
| 890 | + base_key: str = "", |
| 891 | + save_models: bool = True, |
| 892 | + save_jit: bool = False, |
| 893 | + ): |
| 894 | + """Saves the LUMETorchModel wrapper configuration to a YAML file. |
| 895 | +
|
| 896 | + This method serializes both the underlying torch_model and the wrapper's |
| 897 | + initial_inputs state. The torch model is saved using its own dump method, |
| 898 | + and the wrapper configuration references the torch model file. |
| 899 | +
|
| 900 | + Parameters |
| 901 | + ---------- |
| 902 | + file : str or os.PathLike |
| 903 | + File path to which the YAML formatted string and corresponding files are saved. |
| 904 | + base_key : str, optional |
| 905 | + Base key for serialization. |
| 906 | + save_models : bool, optional |
| 907 | + Determines whether models are saved to file. |
| 908 | + save_jit : bool, optional |
| 909 | + Determines whether the model is saved as TorchScript. |
| 910 | +
|
| 911 | + """ |
| 912 | + logger.info(f"Dumping LUMETorchModel wrapper configuration to: {file}") |
| 913 | + |
| 914 | + # Get the file prefix for the wrapper |
| 915 | + file_prefix = os.path.splitext(os.path.abspath(file))[0] |
| 916 | + file_dir = os.path.dirname(os.path.abspath(file)) |
| 917 | + filename_prefix = os.path.basename(file_prefix) |
| 918 | + |
| 919 | + # Create a filename for the underlying torch model |
| 920 | + torch_model_filename = f"{filename_prefix}_torch_model.yaml" |
| 921 | + torch_model_filepath = os.path.join(file_dir, torch_model_filename) |
| 922 | + |
| 923 | + # Dump the underlying torch model |
| 924 | + logger.debug(f"Dumping underlying torch model to: {torch_model_filepath}") |
| 925 | + self.torch_model.dump( |
| 926 | + torch_model_filepath, |
| 927 | + base_key=base_key, |
| 928 | + save_models=save_models, |
| 929 | + save_jit=save_jit, |
| 930 | + ) |
| 931 | + |
| 932 | + # Get the fully qualified model class name |
| 933 | + torch_model_class = self.torch_model.__class__ |
| 934 | + torch_model_class_path = ( |
| 935 | + f"{torch_model_class.__module__}.{torch_model_class.__name__}" |
| 936 | + ) |
| 937 | + |
| 938 | + # Create wrapper configuration |
| 939 | + wrapper_config = { |
| 940 | + "model_class": "LUMETorchModel", |
| 941 | + "torch_model_file": torch_model_filename, # Relative path |
| 942 | + "torch_model_class": torch_model_class_path, # Store for easier loading |
| 943 | + "initial_inputs": self._initial_inputs, |
| 944 | + } |
| 945 | + |
| 946 | + # Write wrapper configuration to file |
| 947 | + with open(file, "w") as f: |
| 948 | + yaml.dump(wrapper_config, f, default_flow_style=None, sort_keys=False) |
| 949 | + |
| 950 | + logger.info(f"Successfully dumped LUMETorchModel wrapper to: {file}") |
| 951 | + |
| 952 | + @classmethod |
| 953 | + def from_file(cls, filename: str): |
| 954 | + """Loads a LUMETorchModel from a YAML file. |
| 955 | +
|
| 956 | + Parameters |
| 957 | + ---------- |
| 958 | + filename : str |
| 959 | + Path to the YAML file containing the wrapper configuration. |
| 960 | +
|
| 961 | + Returns |
| 962 | + ------- |
| 963 | + LUMETorchModel |
| 964 | + Instance of the wrapper loaded from the file. |
| 965 | +
|
| 966 | + Raises |
| 967 | + ------ |
| 968 | + OSError |
| 969 | + If the file does not exist. |
| 970 | +
|
| 971 | + """ |
| 972 | + if not os.path.exists(filename): |
| 973 | + raise OSError(f"File {filename} is not found.") |
| 974 | + |
| 975 | + logger.info(f"Loading LUMETorchModel from file: {filename}") |
| 976 | + with open(filename, "r") as file: |
| 977 | + return cls.from_yaml(file, filename) |
| 978 | + |
| 979 | + @classmethod |
| 980 | + def from_yaml( |
| 981 | + cls, yaml_obj: Union[str, TextIOWrapper], config_file: Optional[str] = None |
| 982 | + ): |
| 983 | + """Loads a LUMETorchModel from a YAML string or file object. |
| 984 | +
|
| 985 | + Parameters |
| 986 | + ---------- |
| 987 | + yaml_obj : str or TextIOWrapper |
| 988 | + YAML formatted string or file object containing the wrapper configuration. |
| 989 | + config_file : str, optional |
| 990 | + Path to the configuration file (used to resolve relative paths). |
| 991 | +
|
| 992 | + Returns |
| 993 | + ------- |
| 994 | + LUMETorchModel |
| 995 | + Instance of the wrapper loaded from the YAML configuration. |
| 996 | +
|
| 997 | + Raises |
| 998 | + ------ |
| 999 | + ValueError |
| 1000 | + If the configuration is invalid or torch_model_file is missing. |
| 1001 | +
|
| 1002 | + """ |
| 1003 | + # Load the YAML configuration |
| 1004 | + if isinstance(yaml_obj, TextIOWrapper): |
| 1005 | + logger.debug(f"Reading configuration from file wrapper: {yaml_obj.name}") |
| 1006 | + config = yaml.safe_load(yaml_obj.read()) |
| 1007 | + config_file = os.path.abspath(yaml_obj.name) |
| 1008 | + elif isinstance(yaml_obj, str): |
| 1009 | + if os.path.exists(yaml_obj): |
| 1010 | + logger.debug(f"Loading configuration from file: {yaml_obj}") |
| 1011 | + with open(yaml_obj, "r") as f: |
| 1012 | + config = yaml.safe_load(f.read()) |
| 1013 | + config_file = os.path.abspath(yaml_obj) |
| 1014 | + else: |
| 1015 | + logger.debug("Parsing configuration from YAML string") |
| 1016 | + config = yaml.safe_load(yaml_obj) |
| 1017 | + else: |
| 1018 | + raise ValueError("yaml_obj must be a string or file object") |
| 1019 | + |
| 1020 | + # Validate configuration |
| 1021 | + if "torch_model_file" not in config: |
| 1022 | + raise ValueError( |
| 1023 | + "Configuration must include 'torch_model_file' specifying the path " |
| 1024 | + "to the underlying torch model YAML file." |
| 1025 | + ) |
| 1026 | + |
| 1027 | + # Resolve the torch model file path |
| 1028 | + torch_model_file = config["torch_model_file"] |
| 1029 | + if config_file is not None and not os.path.isabs(torch_model_file): |
| 1030 | + # Resolve relative path |
| 1031 | + config_dir = os.path.dirname(config_file) |
| 1032 | + torch_model_file = os.path.join(config_dir, torch_model_file) |
| 1033 | + |
| 1034 | + logger.debug(f"Loading underlying torch model from: {torch_model_file}") |
| 1035 | + |
| 1036 | + # Try to get the model class path from wrapper config first |
| 1037 | + model_class_name = config.get("torch_model_class") |
| 1038 | + |
| 1039 | + # If not in wrapper config, read from torch model file |
| 1040 | + if model_class_name is None: |
| 1041 | + with open(torch_model_file, "r") as f: |
| 1042 | + torch_config = yaml.safe_load(f.read()) |
| 1043 | + |
| 1044 | + if "model_class" not in torch_config: |
| 1045 | + raise ValueError( |
| 1046 | + f"Torch model configuration in {torch_model_file} must include 'model_class'" |
| 1047 | + ) |
| 1048 | + model_class_name = torch_config["model_class"] |
| 1049 | + |
| 1050 | + # Try to import the model class |
| 1051 | + torch_model_class = None |
| 1052 | + |
| 1053 | + # First check if it's in lume_torch.models |
| 1054 | + try: |
| 1055 | + from lume_torch.models import get_model |
| 1056 | + |
| 1057 | + torch_model_class = get_model(model_class_name) |
| 1058 | + logger.debug( |
| 1059 | + f"Loaded model class from lume_torch.models: {model_class_name}" |
| 1060 | + ) |
| 1061 | + except (KeyError, ImportError) as e: |
| 1062 | + logger.debug( |
| 1063 | + f"Could not load model class {model_class_name} from lume_torch.models: {e}" |
| 1064 | + ) |
| 1065 | + |
| 1066 | + # If not found, try to import from the module path if it's a fully qualified name |
| 1067 | + if torch_model_class is None and "." in model_class_name: |
| 1068 | + try: |
| 1069 | + module_path, class_name = model_class_name.rsplit(".", 1) |
| 1070 | + module = importlib.import_module(module_path) |
| 1071 | + torch_model_class = getattr(module, class_name) |
| 1072 | + logger.debug(f"Loaded model class from module path: {model_class_name}") |
| 1073 | + except (ImportError, AttributeError) as e: |
| 1074 | + logger.debug( |
| 1075 | + f"Could not import from module path {model_class_name}: {e}" |
| 1076 | + ) |
| 1077 | + |
| 1078 | + # If still not found, raise an error |
| 1079 | + if torch_model_class is None: |
| 1080 | + raise ValueError( |
| 1081 | + f"Could not load model class: {model_class_name}. " |
| 1082 | + "The class must be either registered in lume_torch.models or " |
| 1083 | + "accessible via a fully qualified module path." |
| 1084 | + ) |
| 1085 | + |
| 1086 | + # Load the torch model |
| 1087 | + torch_model = torch_model_class.from_file(torch_model_file) |
| 1088 | + |
| 1089 | + # Get initial inputs from config |
| 1090 | + initial_inputs = config.get("initial_inputs", None) |
| 1091 | + |
| 1092 | + logger.info("Successfully loaded LUMETorchModel wrapper") |
| 1093 | + |
| 1094 | + # Create and return the wrapper instance |
| 1095 | + return cls(torch_model=torch_model, initial_inputs=initial_inputs) |
0 commit comments