67 changes: 50 additions & 17 deletions deepmd/pt/train/training.py
@@ -441,19 +441,34 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
             optimizer_state_dict = None
             if resuming:
                 log.info(f"Resuming from {resume_model}.")
-                state_dict = torch.load(
-                    resume_model, map_location=DEVICE, weights_only=True
-                )
-                if "model" in state_dict:
-                    optimizer_state_dict = (
-                        state_dict["optimizer"] if finetune_model is None else None
-                    )
-                    state_dict = state_dict["model"]
-                self.start_step = (
-                    state_dict["_extra_state"]["train_infos"]["step"]
-                    if self.restart_training
-                    else 0
-                )
+                # Load model parameters based on file extension
+                if resume_model.endswith(".pt"):
+                    # Load checkpoint file (.pt)
+                    state_dict = torch.load(
+                        resume_model, map_location=DEVICE, weights_only=True
+                    )
+                    if "model" in state_dict:
+                        optimizer_state_dict = (
+                            state_dict["optimizer"] if finetune_model is None else None
+                        )
+                        state_dict = state_dict["model"]
+                    self.start_step = (
+                        state_dict["_extra_state"]["train_infos"]["step"]
+                        if self.restart_training
+                        else 0
+                    )
+                elif resume_model.endswith(".pth"):
+                    # Load frozen model (.pth) - no optimizer state or step info available
+                    jit_model = torch.jit.load(resume_model, map_location=DEVICE)
+                    state_dict = jit_model.state_dict()
+                    # For .pth files, we cannot load optimizer state or step info
+                    optimizer_state_dict = None
+                    self.start_step = 0
+                else:
+                    raise RuntimeError(
+                        "The resume model provided must be a checkpoint file with a .pt extension "
+                        "or a frozen model with a .pth extension"
+                    )
                 if self.rank == 0:
                     if force_load:
                         input_keys = list(state_dict.keys())
@@ -483,11 +498,27 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
                 new_state_dict = {}
                 target_state_dict = self.wrapper.state_dict()
                 # pretrained_model
-                pretrained_model = get_model_for_wrapper(
-                    state_dict["_extra_state"]["model_params"]
-                )
+                if resume_model.endswith(".pt"):
+                    # For .pt files, get model params from _extra_state
+                    pretrained_model_params = state_dict["_extra_state"][
+                        "model_params"
+                    ]
+                elif resume_model.endswith(".pth"):
+                    # For .pth files, the model params were already extracted in get_finetune_rules
+                    # We can reconstruct them from the current wrapper's model_params
+                    pretrained_model_params = self.wrapper.get_extra_state()[
+                        "model_params"
+                    ]
+                else:
+                    raise RuntimeError(
+                        "Unsupported finetune model format. Expected .pt or .pth file."
+                    )
+
+                pretrained_model = get_model_for_wrapper(pretrained_model_params)
                 pretrained_model_wrapper = ModelWrapper(pretrained_model)
-                pretrained_model_wrapper.load_state_dict(state_dict)
+                pretrained_model_wrapper.load_state_dict(
+                    state_dict, strict=not resume_model.endswith(".pth")
+                )
                 # update type related params
                 for model_key in self.model_keys:
                     finetune_rule_single = self.finetune_links[model_key]
@@ -571,7 +602,9 @@ def collect_single_finetune_params(
"_extra_state"
]

self.wrapper.load_state_dict(state_dict)
self.wrapper.load_state_dict(
state_dict, strict=not resume_model.endswith(".pth")
)

# change bias for fine-tuning
if finetune_model is not None:
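Standing alone, the resume branch above reduces to the following sketch. It is illustrative only: `demo_load_resume_state` and the bare `DEVICE` constant are not part of the diff, and the finetune/restart special cases of the real code are folded away for brevity; the `torch.load` / `torch.jit.load` split and the checkpoint layout follow the changed code.

```python
# Minimal standalone sketch of the resume branching above, assuming
# DeePMD-kit's checkpoint layout; demo_load_resume_state and DEVICE
# are illustrative names, not part of the diff.
import torch

DEVICE = "cpu"  # stand-in for deepmd.pt.utils.env.DEVICE


def demo_load_resume_state(resume_model: str):
    if resume_model.endswith(".pt"):
        # A .pt checkpoint is an ordinary pickled dict that may nest the
        # weights under "model" and keep optimizer state plus the step
        # counter in "_extra_state".
        state_dict = torch.load(resume_model, map_location=DEVICE, weights_only=True)
        optimizer_state_dict = None
        if "model" in state_dict:
            optimizer_state_dict = state_dict.get("optimizer")
            state_dict = state_dict["model"]
        start_step = state_dict["_extra_state"]["train_infos"]["step"]
    elif resume_model.endswith(".pth"):
        # A .pth frozen model is TorchScript: only weights survive, so
        # neither optimizer state nor a step counter can be recovered.
        state_dict = torch.jit.load(resume_model, map_location=DEVICE).state_dict()
        optimizer_state_dict = None
        start_step = 0
    else:
        raise RuntimeError("expected a .pt checkpoint or a .pth frozen model")
    return state_dict, optimizer_state_dict, start_step
```

The `strict=not resume_model.endswith(".pth")` arguments in the diff follow from the same asymmetry: a TorchScript state_dict need not cover every key the training-time `ModelWrapper` expects, so non-strict loading tolerates the gaps.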
26 changes: 22 additions & 4 deletions deepmd/pt/utils/finetune.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import json
 import logging
 from copy import (
     deepcopy,
@@ -148,10 +149,27 @@ def get_finetune_rules(
         Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs.
     """
     multi_task = "model_dict" in model_config
-    state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True)
-    if "model" in state_dict:
-        state_dict = state_dict["model"]
-    last_model_params = state_dict["_extra_state"]["model_params"]
+
+    # Load model parameters based on file extension
+    if finetune_model.endswith(".pt"):
+        # Load checkpoint file (.pt)
+        state_dict = torch.load(
+            finetune_model, map_location=env.DEVICE, weights_only=True
+        )
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
+        last_model_params = state_dict["_extra_state"]["model_params"]
+    elif finetune_model.endswith(".pth"):
+        # Load frozen model (.pth)
+        jit_model = torch.jit.load(finetune_model, map_location=env.DEVICE)
+        model_params_string = jit_model.get_model_def_script()
+        last_model_params = json.loads(model_params_string)
+    else:
+        raise RuntimeError(
+            "The finetune model provided must be a checkpoint file with a .pt extension "
+            "or a frozen model with a .pth extension"
+        )
+
     finetune_from_multi_task = "model_dict" in last_model_params
     finetune_links = {}
     if not multi_task:
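The `.pth` branch here leans on the fact that DeePMD-kit frozen models embed their model definition as a JSON string, exposed through the `get_model_def_script()` accessor the diff itself calls. A minimal sketch of that round trip, with a hypothetical file path:

```python
# Sketch: recover model params from a frozen DeePMD-kit model.
# "frozen_model.pth" is a hypothetical path; get_model_def_script()
# is the accessor used by the diff on the loaded ScriptModule.
import json

import torch

jit_model = torch.jit.load("frozen_model.pth", map_location="cpu")
last_model_params = json.loads(jit_model.get_model_def_script())
print(last_model_params["type_map"])  # e.g. ["O", "H"] for a water model
```

Since a frozen model stores no training state, only the architecture and weights can be recovered this way, which is consistent with training.py resetting `start_step` to 0 for `.pth` input.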
75 changes: 75 additions & 0 deletions source/tests/pt/test_finetune.py
@@ -367,5 +367,80 @@ def setUp(self) -> None:
         self.testkey = None
 
 
+class TestFinetuneFromPthModel(unittest.TestCase):
+    """Test that finetuning from .pth (frozen) models works correctly."""
+
+    def setUp(self) -> None:
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        self.data_file = [str(Path(__file__).parent / "water/data/single")]
+        self.config["training"]["training_data"]["systems"] = self.data_file
+        self.config["training"]["validation_data"]["systems"] = self.data_file
+        self.config["model"] = deepcopy(model_se_e2_a)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+
+    def test_finetune_from_pth_model(self) -> None:
+        """Test that get_finetune_rules works with .pth (frozen) models."""
+        # Train initial model
+        trainer = get_trainer(self.config)
+        trainer.run()
+
+        # Freeze model to .pth
+        from deepmd.pt.entrypoints.main import (
+            freeze,
+        )
+
+        freeze(model="model.pt", output="test_model.pth")
+
+        # Test finetuning from .pth model
+        finetune_config = deepcopy(self.config)
+        finetune_config["model"], finetune_links = get_finetune_rules(
+            "test_model.pth",  # This is a .pth file
+            finetune_config["model"],
+            change_model_params=True,
+        )
+
+        # Verify the finetune rules were created successfully
+        self.assertIsNotNone(finetune_links)
+        self.assertIn("Default", finetune_links)
+        self.assertIsInstance(finetune_config["model"], dict)
+        self.assertIn("type_map", finetune_config["model"])
+
+        self.tearDown()
+
+    def test_finetune_from_pt_model_still_works(self) -> None:
+        """Test that original .pt finetuning still works after our changes."""
+        # Train initial model
+        trainer = get_trainer(self.config)
+        trainer.run()
+
+        # Test finetuning from .pt model (original functionality)
+        finetune_config = deepcopy(self.config)
+        finetune_config["model"], finetune_links = get_finetune_rules(
+            "model.pt",  # This is a .pt file
+            finetune_config["model"],
+            change_model_params=True,
+        )
+
+        # Verify the finetune rules were created successfully
+        self.assertIsNotNone(finetune_links)
+        self.assertIn("Default", finetune_links)
+        self.assertIsInstance(finetune_config["model"], dict)
+        self.assertIn("type_map", finetune_config["model"])
+
+        self.tearDown()
+
+    def tearDown(self) -> None:
+        for f in os.listdir("."):
+            # also match "test_model.pth" written by the freeze step above
+            if f.startswith(("model", "test_model")) and (f.endswith(".pt") or f.endswith(".pth")):
+                os.remove(f)
+            if f in ["lcurve.out"]:
+                os.remove(f)
+            if f in ["stat_files"]:
+                shutil.rmtree(f)
+
+
 if __name__ == "__main__":
     unittest.main()
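Assuming a DeePMD-kit checkout with `deepmd` importable and the water example data in place, the new tests can be exercised with standard unittest discovery from the repository root; this invocation is an assumption, not part of the diff:

```python
# Run the finetune tests added above; assumes execution from the repo root.
import unittest

suite = unittest.TestLoader().discover("source/tests/pt", pattern="test_finetune.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```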