Skip to content

Commit 57e1f4e

Browse files
authored
feat(pt/tf): add bias changing param/interface (#3933)
Add a bias-changing parameter and interface.

For pt/tf, add `training/change_bias_after_training` to change the output bias once after training.

For pt, add a separate command `change-bias` to change the output bias of a trained model (pt/pth, multi/single) for specific data:

```
dp change-bias model.pt -s data -n 10 -m change
```

UTs for this feature are still under consideration.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Added a new subcommand `change-bias` to adjust model output bias in the PyTorch backend.
  - Introduced test cases for changing model biases via a new test suite.
- **Documentation**
  - Added documentation for the new `change-bias` command, including usage and options.
  - Updated `index.rst` to include a new entry for `change-bias` under the `Model` section.
- **Bug Fixes**
  - Adjusted data handling in `make_stat_input` to limit processing to a specified number of batches.
- **Refactor**
  - Restructured training configuration to include the parameter `change_bias_after_training`.
  - Modularized data requirement handling and bias adjustment functions.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent c98185c commit 57e1f4e

File tree

12 files changed

+534
-84
lines changed

12 files changed

+534
-84
lines changed

deepmd/main.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser:
659659
help="treat all types as a single type. Used with se_atten descriptor.",
660660
)
661661

662+
# change_bias
663+
parser_change_bias = subparsers.add_parser(
664+
"change-bias",
665+
parents=[parser_log],
666+
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
667+
formatter_class=RawTextArgumentDefaultsHelpFormatter,
668+
epilog=textwrap.dedent(
669+
"""\
670+
examples:
671+
dp change-bias model.pt -s data -n 10 -m change
672+
"""
673+
),
674+
)
675+
parser_change_bias.add_argument(
676+
"INPUT", help="The input checkpoint file or frozen model file"
677+
)
678+
parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group()
679+
parser_change_bias_source.add_argument(
680+
"-s",
681+
"--system",
682+
default=".",
683+
type=str,
684+
help="The system dir. Recursively detect systems in this directory",
685+
)
686+
parser_change_bias_source.add_argument(
687+
"-b",
688+
"--bias-value",
689+
default=None,
690+
type=float,
691+
nargs="+",
692+
help="The user defined value for each type in the type_map of the model, split with spaces.\n"
693+
"For example, '-93.57 -187.1' for energy bias of two elements. "
694+
"Only supports energy bias changing.",
695+
)
696+
parser_change_bias.add_argument(
697+
"-n",
698+
"--numb-batch",
699+
default=0,
700+
type=int,
701+
help="The number of frames for bias changing in one data system. 0 means all data.",
702+
)
703+
parser_change_bias.add_argument(
704+
"-m",
705+
"--mode",
706+
type=str,
707+
default="change",
708+
choices=["change", "set"],
709+
help="The mode for changing energy bias: \n"
710+
"change (default) : perform predictions using input model on target dataset, "
711+
"and do least square on the errors to obtain the target shift as bias.\n"
712+
"set : directly use the statistic bias in the target dataset.",
713+
)
714+
parser_change_bias.add_argument(
715+
"-o",
716+
"--output",
717+
default=None,
718+
type=str,
719+
help="The model after changing bias.",
720+
)
721+
parser_change_bias.add_argument(
722+
"--model-branch",
723+
type=str,
724+
default=None,
725+
help="Model branch chosen for changing bias if multi-task model.",
726+
)
727+
662728
# --version
663729
parser.add_argument(
664730
"--version", action="version", version=f"DeePMD-kit v{__version__}"
@@ -831,6 +897,7 @@ def main():
831897
"convert-from",
832898
"train-nvnmd",
833899
"show",
900+
"change-bias",
834901
):
835902
deepmd_main = BACKENDS[args.backend]().entry_point_hook
836903
elif args.command is None:

deepmd/pt/entrypoints/main.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import argparse
3+
import copy
34
import json
45
import logging
56
import os
@@ -23,6 +24,9 @@
2324
from deepmd import (
2425
__version__,
2526
)
27+
from deepmd.common import (
28+
expand_sys_str,
29+
)
2630
from deepmd.env import (
2731
GLOBAL_CONFIG,
2832
)
@@ -44,6 +48,9 @@
4448
from deepmd.pt.train import (
4549
training,
4650
)
51+
from deepmd.pt.train.wrapper import (
52+
ModelWrapper,
53+
)
4754
from deepmd.pt.utils import (
4855
env,
4956
)
@@ -59,6 +66,12 @@
5966
from deepmd.pt.utils.multi_task import (
6067
preprocess_shared_params,
6168
)
69+
from deepmd.pt.utils.stat import (
70+
make_stat_input,
71+
)
72+
from deepmd.pt.utils.utils import (
73+
to_numpy_array,
74+
)
6275
from deepmd.utils.argcheck import (
6376
normalize,
6477
)
@@ -376,6 +389,128 @@ def show(FLAGS):
376389
log.info(f"The fitting_net parameter is {fitting_net}")
377390

378391

392+
def change_bias(FLAGS):
    """Change the output bias of a trained model and save the updated model.

    The new bias is either taken directly from ``FLAGS.bias_value`` or
    computed from the data systems found under ``FLAGS.system``.

    Parameters
    ----------
    FLAGS
        Parsed command-line arguments. The attributes read here are
        ``INPUT``, ``model_branch``, ``mode``, ``bias_value``, ``system``,
        ``numb_batch`` and ``output``.

    Raises
    ------
    RuntimeError
        If ``FLAGS.INPUT`` ends with neither ``.pt`` nor ``.pth``.
    AssertionError
        If the model is multi-task and ``FLAGS.model_branch`` is missing or
        unknown, or if user-defined bias values are given for a non-energy
        model or do not match the length of the type map.
    """
    is_checkpoint = FLAGS.INPUT.endswith(".pt")
    if is_checkpoint:
        old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        # The weights may sit under a "model" key or be the top-level dict.
        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
        model_params = model_state_dict["_extra_state"]["model_params"]
    elif FLAGS.INPUT.endswith(".pth"):
        old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
        model_params_string = old_model.get_model_def_script()
        model_params = json.loads(model_params_string)
        old_state_dict = old_model.state_dict()
        model_state_dict = old_state_dict
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )
    multi_task = "model_dict" in model_params
    model_branch = FLAGS.model_branch
    bias_adjust_mode = (
        "change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
    )
    if multi_task:
        assert (
            model_branch is not None
        ), "For multitask model, the model branch must be set!"
        assert model_branch in model_params["model_dict"], (
            f"For multitask model, the model branch must be in the 'model_dict'! "
            f"Available options are : {list(model_params['model_dict'].keys())}."
        )
        log.info(f"Changing out bias for model {model_branch}.")
    model = training.get_model_for_wrapper(model_params)
    type_map = (
        model_params["type_map"]
        if not multi_task
        else model_params["model_dict"][model_branch]["type_map"]
    )
    model_to_change = model if not multi_task else model[model_branch]
    if is_checkpoint:
        wrapper = ModelWrapper(model)
        # Load from the dict resolved above: indexing old_state_dict["model"]
        # would raise KeyError for a checkpoint without the "model" key, which
        # both the loading code above and the saving code below accept.
        wrapper.load_state_dict(model_state_dict)
    else:
        # for .pth
        model.load_state_dict(old_state_dict)

    if FLAGS.bias_value is not None:
        # use user-defined bias
        assert model_to_change.model_type in [
            "ener"
        ], "User-defined bias is only available for energy model!"
        assert (
            len(FLAGS.bias_value) == len(type_map)
        ), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
        old_bias = model_to_change.get_out_bias()
        bias_to_set = torch.tensor(
            FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
        ).view(old_bias.shape)
        model_to_change.set_out_bias(bias_to_set)
        log.info(
            f"Change output bias of {type_map!s} "
            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
            f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
        )
        updated_model = model_to_change
    else:
        # calculate bias on given systems
        data_systems = process_systems(expand_sys_str(FLAGS.system))
        data_single = DpLoaderSet(
            data_systems,
            1,
            type_map,
        )
        # A loss built in inference mode is used only to obtain the labels
        # that must be loaded from the data.
        mock_loss = training.get_loss(
            {"inference": True}, 1.0, len(type_map), model_to_change
        )
        data_requirement = mock_loss.label_requirement
        data_requirement += training.get_additional_data_requirement(model_to_change)
        data_single.add_data_requirement(data_requirement)
        # 0 on the command line means "use all data".
        nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
        sampled_data = make_stat_input(
            data_single.systems,
            data_single.dataloaders,
            nbatches,
        )
        updated_model = training.model_change_out_bias(
            model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
        )

    if not multi_task:
        model = updated_model
    else:
        model[model_branch] = updated_model

    if is_checkpoint:
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else _bias_updated_path(FLAGS.INPUT, ".pt")
        )
        wrapper = ModelWrapper(model)
        # Keep the original layout (with or without the "model" key) and carry
        # over the original extra state.
        if "model" in old_state_dict:
            old_state_dict["model"] = wrapper.state_dict()
            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
        else:
            old_state_dict = wrapper.state_dict()
            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
        torch.save(old_state_dict, output_path)
    else:
        # for .pth
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else _bias_updated_path(FLAGS.INPUT, ".pth")
        )
        model = torch.jit.script(model)
        torch.jit.save(
            model,
            output_path,
            {},
        )
    log.info(f"Saved model to {output_path}")


def _bias_updated_path(input_path: str, suffix: str) -> str:
    """Return ``input_path`` with its trailing ``suffix`` turned into ``_updated<suffix>``.

    Only the final extension is rewritten; ``str.replace`` would also rewrite
    any earlier occurrence of the suffix inside the path (e.g. in a directory
    name). ``input_path`` is expected to end with ``suffix``.
    """
    return f"{input_path[: -len(suffix)]}_updated{suffix}"
512+
513+
379514
@record
380515
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
381516
if not isinstance(args, argparse.Namespace):
@@ -400,6 +535,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
400535
freeze(FLAGS)
401536
elif FLAGS.command == "show":
402537
show(FLAGS)
538+
elif FLAGS.command == "change-bias":
539+
change_bias(FLAGS)
403540
else:
404541
raise RuntimeError(f"Invalid command {FLAGS.command}!")
405542

deepmd/pt/loss/ener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
9797
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
9898
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
99-
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference
99+
self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0
100100

101101
self.start_pref_e = start_pref_e
102102
self.limit_pref_e = limit_pref_e

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def init_out_stat(self):
103103
self.register_buffer("out_bias", out_bias_data)
104104
self.register_buffer("out_std", out_std_data)
105105

106+
def set_out_bias(self, out_bias: torch.Tensor) -> None:
    """Store ``out_bias`` as this object's ``out_bias`` attribute.

    Parameters
    ----------
    out_bias
        The new output bias. No shape or dtype check is performed here.
    """
    self.out_bias = out_bias
108+
106109
def __setitem__(self, key, value):
107110
if key in ["out_bias"]:
108111
self.out_bias = value

deepmd/pt/model/model/make_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def forward_common(
175175
def get_out_bias(self) -> torch.Tensor:
176176
return self.atomic_model.get_out_bias()
177177

178+
def set_out_bias(self, out_bias: torch.Tensor) -> None:
    """Forward ``out_bias`` to the wrapped atomic model's ``set_out_bias``.

    Parameters
    ----------
    out_bias
        The new output bias, passed through unchanged.
    """
    atomic_model = self.atomic_model
    atomic_model.set_out_bias(out_bias)
180+
178181
def change_out_bias(
179182
self,
180183
merged,

0 commit comments

Comments
 (0)