diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index ab98389426..6ada3543a2 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -441,19 +441,34 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
         optimizer_state_dict = None
         if resuming:
             log.info(f"Resuming from {resume_model}.")
-            state_dict = torch.load(
-                resume_model, map_location=DEVICE, weights_only=True
-            )
-            if "model" in state_dict:
-                optimizer_state_dict = (
-                    state_dict["optimizer"] if finetune_model is None else None
+            # Load model parameters based on file extension
+            if resume_model.endswith(".pt"):
+                # Load checkpoint file (.pt)
+                state_dict = torch.load(
+                    resume_model, map_location=DEVICE, weights_only=True
+                )
+                if "model" in state_dict:
+                    optimizer_state_dict = (
+                        state_dict["optimizer"] if finetune_model is None else None
+                    )
+                    state_dict = state_dict["model"]
+                self.start_step = (
+                    state_dict["_extra_state"]["train_infos"]["step"]
+                    if self.restart_training
+                    else 0
+                )
+            elif resume_model.endswith(".pth"):
+                # Load frozen model (.pth) - no optimizer state or step info available
+                jit_model = torch.jit.load(resume_model, map_location=DEVICE)
+                state_dict = jit_model.state_dict()
+                # For .pth files, we cannot load optimizer state or step info
+                optimizer_state_dict = None
+                self.start_step = 0
+            else:
+                raise RuntimeError(
+                    "The resume model provided must be a checkpoint file with a .pt extension "
+                    "or a frozen model with a .pth extension"
                 )
-                state_dict = state_dict["model"]
-            self.start_step = (
-                state_dict["_extra_state"]["train_infos"]["step"]
-                if self.restart_training
-                else 0
-            )
             if self.rank == 0:
                 if force_load:
                     input_keys = list(state_dict.keys())
@@ -483,11 +498,27 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
                     new_state_dict = {}
                     target_state_dict = self.wrapper.state_dict()
                     # pretrained_model
-                    pretrained_model = get_model_for_wrapper(
-                        state_dict["_extra_state"]["model_params"]
-                    )
+                    if resume_model.endswith(".pt"):
+                        # For .pt files, get model params from _extra_state
+                        pretrained_model_params = state_dict["_extra_state"][
+                            "model_params"
+                        ]
+                    elif resume_model.endswith(".pth"):
+                        # For .pth files, the model params were already extracted in get_finetune_rules
+                        # We can reconstruct them from the current wrapper's model_params
+                        pretrained_model_params = self.wrapper.get_extra_state()[
+                            "model_params"
+                        ]
+                    else:
+                        raise RuntimeError(
+                            "Unsupported finetune model format. Expected .pt or .pth file."
+                        )
+
+                    pretrained_model = get_model_for_wrapper(pretrained_model_params)
                     pretrained_model_wrapper = ModelWrapper(pretrained_model)
-                    pretrained_model_wrapper.load_state_dict(state_dict)
+                    pretrained_model_wrapper.load_state_dict(
+                        state_dict, strict=not resume_model.endswith(".pth")
+                    )
                     # update type related params
                     for model_key in self.model_keys:
                         finetune_rule_single = self.finetune_links[model_key]
@@ -571,7 +602,9 @@ def collect_single_finetune_params(
                         "_extra_state"
                     ]
 
-            self.wrapper.load_state_dict(state_dict)
+            self.wrapper.load_state_dict(
+                state_dict, strict=not resume_model.endswith(".pth")
+            )
 
         # change bias for fine-tuning
         if finetune_model is not None:
diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py
index 0e86c9aa6c..f4fb614c8b 100644
--- a/deepmd/pt/utils/finetune.py
+++ b/deepmd/pt/utils/finetune.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import json
 import logging
 from copy import (
     deepcopy,
@@ -148,10 +149,27 @@ def get_finetune_rules(
         Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs.
     """
     multi_task = "model_dict" in model_config
-    state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True)
-    if "model" in state_dict:
-        state_dict = state_dict["model"]
-    last_model_params = state_dict["_extra_state"]["model_params"]
+
+    # Load model parameters based on file extension
+    if finetune_model.endswith(".pt"):
+        # Load checkpoint file (.pt)
+        state_dict = torch.load(
+            finetune_model, map_location=env.DEVICE, weights_only=True
+        )
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
+        last_model_params = state_dict["_extra_state"]["model_params"]
+    elif finetune_model.endswith(".pth"):
+        # Load frozen model (.pth)
+        jit_model = torch.jit.load(finetune_model, map_location=env.DEVICE)
+        model_params_string = jit_model.get_model_def_script()
+        last_model_params = json.loads(model_params_string)
+    else:
+        raise RuntimeError(
+            "The finetune model provided must be a checkpoint file with a .pt extension "
+            "or a frozen model with a .pth extension"
+        )
+
     finetune_from_multi_task = "model_dict" in last_model_params
     finetune_links = {}
     if not multi_task:
diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py
index 8998392830..69ccca6afc 100644
--- a/source/tests/pt/test_finetune.py
+++ b/source/tests/pt/test_finetune.py
@@ -367,5 +367,80 @@ def setUp(self) -> None:
         self.testkey = None
 
 
+class TestFinetuneFromPthModel(unittest.TestCase):
+    """Test that finetuning from .pth (frozen) models works correctly."""
+
+    def setUp(self) -> None:
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        self.data_file = [str(Path(__file__).parent / "water/data/single")]
+        self.config["training"]["training_data"]["systems"] = self.data_file
+        self.config["training"]["validation_data"]["systems"] = self.data_file
+        self.config["model"] = deepcopy(model_se_e2_a)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+
+    def test_finetune_from_pth_model(self) -> None:
+        """Test that get_finetune_rules works with .pth (frozen) models."""
+        # Train initial model
+        trainer = get_trainer(self.config)
+        trainer.run()
+
+        # Freeze model to .pth
+        from deepmd.pt.entrypoints.main import (
+            freeze,
+        )
+
+        freeze(model="model.pt", output="test_model.pth")
+
+        # Test finetuning from .pth model
+        finetune_config = deepcopy(self.config)
+        finetune_config["model"], finetune_links = get_finetune_rules(
+            "test_model.pth",  # This is a .pth file
+            finetune_config["model"],
+            change_model_params=True,
+        )
+
+        # Verify the finetune rules were created successfully
+        self.assertIsNotNone(finetune_links)
+        self.assertIn("Default", finetune_links)
+        self.assertIsInstance(finetune_config["model"], dict)
+        self.assertIn("type_map", finetune_config["model"])
+
+        self.tearDown()
+
+    def test_finetune_from_pt_model_still_works(self) -> None:
+        """Test that original .pt finetuning still works after our changes."""
+        # Train initial model
+        trainer = get_trainer(self.config)
+        trainer.run()
+
+        # Test finetuning from .pt model (original functionality)
+        finetune_config = deepcopy(self.config)
+        finetune_config["model"], finetune_links = get_finetune_rules(
+            "model.pt",  # This is a .pt file
+            finetune_config["model"],
+            change_model_params=True,
+        )
+
+        # Verify the finetune rules were created successfully
+        self.assertIsNotNone(finetune_links)
+        self.assertIn("Default", finetune_links)
+        self.assertIsInstance(finetune_config["model"], dict)
+        self.assertIn("type_map", finetune_config["model"])
+
+        self.tearDown()
+
+    def tearDown(self) -> None:
+        for f in os.listdir("."):
+            if f.startswith("model") and (f.endswith(".pt") or f.endswith(".pth")):
+                os.remove(f)
+            if f in ["lcurve.out"]:
+                os.remove(f)
+            if f in ["stat_files"]:
+                shutil.rmtree(f)
+
+
 if __name__ == "__main__":
     unittest.main()
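For reference, a minimal sketch of how the new `.pth` path would be exercised end to end, mirroring the test added above. The file names and the `input.json` training config are illustrative assumptions; `freeze` and `get_finetune_rules` are the entry points touched in this diff, called with the same arguments as in the test.

```python
import json
from copy import deepcopy

from deepmd.pt.entrypoints.main import freeze
from deepmd.pt.utils.finetune import get_finetune_rules

# Assumes a checkpoint "model.pt" from a previous training run exists
# (all file names here are illustrative).
freeze(model="model.pt", output="frozen_model.pth")

# Load an existing training config and derive fine-tuning rules
# directly from the frozen model instead of a .pt checkpoint.
with open("input.json") as f:
    config = json.load(f)

finetune_config = deepcopy(config)
finetune_config["model"], finetune_links = get_finetune_rules(
    "frozen_model.pth",  # .pth is now accepted alongside .pt
    finetune_config["model"],
    change_model_params=True,
)
print(list(finetune_links))  # a single-task model exposes the "Default" branch
```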