Skip to content
36 changes: 36 additions & 0 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.finetune import (
warn_configuration_mismatch_during_finetune,
)
from deepmd.utils.path import (
DPH5Path,
)
Expand Down Expand Up @@ -117,6 +120,8 @@ def __init__(
training_params = config["training"]
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
# Store model params for finetune warning comparisons
self.model_params = model_params
self.finetune_update_stat = False
self.model_keys = (
list(model_params["model_dict"]) if self.multi_task else ["Default"]
Expand Down Expand Up @@ -512,6 +517,37 @@ def collect_single_finetune_params(
)

# collect model params from the pretrained model
# First check for configuration mismatches and warn if needed
pretrained_model_params = state_dict["_extra_state"]["model_params"]
for model_key in self.model_keys:
finetune_rule_single = self.finetune_links[model_key]
_model_key_from = finetune_rule_single.get_model_branch()

# Get current model descriptor config
if self.multi_task:
current_descriptor = self.model_params["model_dict"][
model_key
].get("descriptor", {})
else:
current_descriptor = self.model_params.get("descriptor", {})

# Get pretrained model descriptor config
if "model_dict" in pretrained_model_params:
pretrained_descriptor = pretrained_model_params[
"model_dict"
][_model_key_from].get("descriptor", {})
else:
pretrained_descriptor = pretrained_model_params.get(
"descriptor", {}
)

# Warn about configuration mismatches
warn_configuration_mismatch_during_finetune(
current_descriptor,
pretrained_descriptor,
_model_key_from,
)

for model_key in self.model_keys:
finetune_rule_single = self.finetune_links[model_key]
collect_single_finetune_params(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pd/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from deepmd.utils.finetune import (
FinetuneRuleItem,
warn_descriptor_config_differences,
)

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -61,6 +62,15 @@ def get_finetune_rule_single(
"descriptor": single_config.get("descriptor", {}).get("trainable", True),
"fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
}

# Warn about descriptor configuration differences before overwriting
if "descriptor" in single_config and "descriptor" in single_config_chosen:
warn_descriptor_config_differences(
single_config["descriptor"],
single_config_chosen["descriptor"],
model_branch_chosen,
)

single_config["descriptor"] = single_config_chosen["descriptor"]
if not new_fitting:
single_config["fitting_net"] = single_config_chosen["fitting_net"]
Expand Down
36 changes: 36 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
DataLoader,
)

from deepmd.utils.finetune import (
warn_configuration_mismatch_during_finetune,
)
from deepmd.utils.path import (
DPH5Path,
)
Expand Down Expand Up @@ -122,6 +125,8 @@ def __init__(
training_params = config["training"]
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
# Store model params for finetune warning comparisons
self.model_params = model_params
self.finetune_update_stat = False
self.model_keys = (
list(model_params["model_dict"]) if self.multi_task else ["Default"]
Expand Down Expand Up @@ -541,6 +546,37 @@ def collect_single_finetune_params(
)

# collect model params from the pretrained model
# First check for configuration mismatches and warn if needed
pretrained_model_params = state_dict["_extra_state"]["model_params"]
for model_key in self.model_keys:
finetune_rule_single = self.finetune_links[model_key]
_model_key_from = finetune_rule_single.get_model_branch()

# Get current model descriptor config
if self.multi_task:
current_descriptor = self.model_params["model_dict"][
model_key
].get("descriptor", {})
else:
current_descriptor = self.model_params.get("descriptor", {})

# Get pretrained model descriptor config
if "model_dict" in pretrained_model_params:
pretrained_descriptor = pretrained_model_params[
"model_dict"
][_model_key_from].get("descriptor", {})
else:
pretrained_descriptor = pretrained_model_params.get(
"descriptor", {}
)

# Warn about configuration mismatches
warn_configuration_mismatch_during_finetune(
current_descriptor,
pretrained_descriptor,
_model_key_from,
)

for model_key in self.model_keys:
finetune_rule_single = self.finetune_links[model_key]
collect_single_finetune_params(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from deepmd.utils.finetune import (
FinetuneRuleItem,
warn_descriptor_config_differences,
)

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,6 +65,15 @@ def get_finetune_rule_single(
"descriptor": single_config.get("descriptor", {}).get("trainable", True),
"fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
}

# Warn about descriptor configuration differences before overwriting
if "descriptor" in single_config and "descriptor" in single_config_chosen:
warn_descriptor_config_differences(
single_config["descriptor"],
single_config_chosen["descriptor"],
model_branch_chosen,
)

single_config["descriptor"] = single_config_chosen["descriptor"]
if not new_fitting:
single_config["fitting_net"] = single_config_chosen["fitting_net"]
Expand Down
179 changes: 179 additions & 0 deletions deepmd/utils/finetune.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,188 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging

from deepmd.utils.argcheck import (
normalize,
)

log = logging.getLogger(__name__)


def warn_descriptor_config_differences(
    input_descriptor: dict,
    pretrained_descriptor: dict,
    model_branch: str = "Default",
) -> None:
    """
    Warn about differences between input descriptor config and pretrained model's descriptor config.

    This function is used when --use-pretrain-script option is used and input configuration
    will be overwritten with the pretrained model's configuration.

    Parameters
    ----------
    input_descriptor : dict
        Descriptor configuration from input.json
    pretrained_descriptor : dict
        Descriptor configuration from pretrained model
    model_branch : str
        Model branch name for logging context
    """
    # Normalize both configurations so argcheck defaults are filled in
    # consistently; this avoids warning about parameters that only differ
    # because one side omitted a default value.
    try:

        def _wrap(descriptor: dict) -> dict:
            # Build an independent minimal config around the descriptor.
            # Each call constructs a FRESH nested dict: the two wrapped
            # configs must not share any sub-dict, otherwise assigning one
            # descriptor would silently overwrite the other (a shallow
            # `dict.copy()` of a shared template has exactly that aliasing
            # problem and makes the comparison always succeed).
            return {
                "model": {
                    "fitting_net": {"neuron": [240, 240, 240]},
                    "type_map": ["H", "O"],
                    "descriptor": descriptor.copy(),
                },
                "training": {
                    "training_data": {"systems": ["fake"]},
                    "numb_steps": 100,
                },
            }

        normalized_input = normalize(_wrap(input_descriptor), multi_task=False)[
            "model"
        ]["descriptor"]
        normalized_pretrained = normalize(
            _wrap(pretrained_descriptor), multi_task=False
        )["model"]["descriptor"]

        if normalized_input == normalized_pretrained:
            return

        # Use normalized configs for comparison to show only meaningful differences
        input_descriptor = normalized_input
        pretrained_descriptor = normalized_pretrained
    except Exception:
        # If normalization fails (e.g. unknown descriptor type), fall back
        # to comparing the raw configs.
        pass

    if input_descriptor == pretrained_descriptor:
        return

    # Collect human-readable difference lines
    differences = []

    # Keys present in the input: value changed, or removed in pretrained model
    for key in input_descriptor:
        if key in pretrained_descriptor:
            if input_descriptor[key] != pretrained_descriptor[key]:
                differences.append(
                    f"  {key}: {input_descriptor[key]} -> {pretrained_descriptor[key]}"
                )
        else:
            differences.append(f"  {key}: {input_descriptor[key]} -> (removed)")

    # Check for keys only in pretrained model
    for key in pretrained_descriptor:
        if key not in input_descriptor:
            differences.append(f"  {key}: (added) -> {pretrained_descriptor[key]}")

    if differences:
        log.warning(
            f"Descriptor configuration in input.json differs from pretrained model "
            f"(branch '{model_branch}'). The input configuration will be overwritten "
            f"with the pretrained model's configuration:\n" + "\n".join(differences)
        )


def warn_configuration_mismatch_during_finetune(
    input_descriptor: dict,
    pretrained_descriptor: dict,
    model_branch: str = "Default",
) -> None:
    """
    Warn about configuration mismatches between input descriptor and pretrained model
    when fine-tuning without --use-pretrain-script option.

    This function warns when configurations differ and state_dict initialization
    will only pick relevant keys from the pretrained model (e.g., first 6 layers
    from a 16-layer model).

    Parameters
    ----------
    input_descriptor : dict
        Descriptor configuration from input.json
    pretrained_descriptor : dict
        Descriptor configuration from pretrained model
    model_branch : str
        Model branch name for logging context
    """
    # Normalize both configurations so argcheck defaults are filled in
    # consistently; this avoids warning about parameters that only differ
    # because one side omitted a default value.
    try:

        def _wrap(descriptor: dict) -> dict:
            # Build an independent minimal config around the descriptor.
            # Each call constructs a FRESH nested dict: the two wrapped
            # configs must not share any sub-dict, otherwise assigning one
            # descriptor would silently overwrite the other (a shallow
            # `dict.copy()` of a shared template has exactly that aliasing
            # problem and makes the comparison always succeed).
            return {
                "model": {
                    "fitting_net": {"neuron": [240, 240, 240]},
                    "type_map": ["H", "O"],
                    "descriptor": descriptor.copy(),
                },
                "training": {
                    "training_data": {"systems": ["fake"]},
                    "numb_steps": 100,
                },
            }

        normalized_input = normalize(_wrap(input_descriptor), multi_task=False)[
            "model"
        ]["descriptor"]
        normalized_pretrained = normalize(
            _wrap(pretrained_descriptor), multi_task=False
        )["model"]["descriptor"]

        if normalized_input == normalized_pretrained:
            return

        # Use normalized configs for comparison to show only meaningful differences
        input_descriptor = normalized_input
        pretrained_descriptor = normalized_pretrained
    except Exception:
        # If normalization fails (e.g. unknown descriptor type), fall back
        # to comparing the raw configs.
        pass

    if input_descriptor == pretrained_descriptor:
        return

    # Collect human-readable difference lines
    differences = []

    # Keys present in the input: value changed, or missing from pretrained model
    for key in input_descriptor:
        if key in pretrained_descriptor:
            if input_descriptor[key] != pretrained_descriptor[key]:
                differences.append(
                    f"  {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
                )
        else:
            differences.append(f"  {key}: {input_descriptor[key]} (input only)")

    # Check for keys only in pretrained model
    for key in pretrained_descriptor:
        if key not in input_descriptor:
            differences.append(
                f"  {key}: {pretrained_descriptor[key]} (pretrained only)"
            )

    if differences:
        log.warning(
            f"Descriptor configuration mismatch detected between input.json and pretrained model "
            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
            f"from the pretrained model. Mismatched configuration:\n"
            + "\n".join(differences)
        )


class FinetuneRuleItem:
def __init__(
self,
Expand Down