Skip to content

Commit 581bcb6

Browse files
Copilotnjzjzgithub-advanced-security[bot]pre-commit-ci[bot]
authored
feat(tf): implement change-bias command (#4927)
Implements TensorFlow support for the `dp change-bias` command with proper checkpoint handling and variable restoration. This brings the TensorFlow backend to feature parity with the PyTorch implementation. ## Key Features - **Checkpoint file support**: Handles individual checkpoint files (`.ckpt`, `.meta`, `.data`, `.index`) and frozen models (`.pb`) - **Proper variable restoration**: Variables are correctly restored from checkpoints using session initialization before bias modification - **User-defined bias support**: Supports `-b/--bias-value` option with proper validation against model type_map - **Data-based bias calculation**: Leverages existing `change_energy_bias_lower` functionality for automatic bias computation - **Checkpoint preservation**: Saves modified variables to separate checkpoint directory for continued training - **Cross-backend consistency**: Identical CLI interface and functionality as PyTorch backend ## Before vs After **Variable restoration**: - Before: `Change energy bias of ['O', 'H'] from [0. 0.] to [calculated values]` (variables never restored) - After: `Change energy bias of ['O', 'H'] from [-93.57 -187.15] to [-93.60 -187.19]` (proper restoration) **Output**: Creates both updated checkpoint files AND frozen model for continued training **Documentation**: Comprehensive documentation covering both TensorFlow and PyTorch backends with examples and backend-specific details The implementation includes comprehensive test coverage with real model training to validate functionality without mocks. Fixes #4018. <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/deepmodeling/deepmd-kit/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Signed-off-by: Jinzhe Zeng <[email protected]> Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: njzjz <[email protected]> Co-authored-by: Jinzhe Zeng <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent db22802 commit 581bcb6

File tree

7 files changed

+714
-12
lines changed

7 files changed

+714
-12
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ buildcxx/
5151
node_modules/
5252
*.bib.original
5353

54+
# Coverage files
55+
.coverage
56+
.coverage.*
57+
5458
# Test output files (temporary)
5559
test_dp_test/
5660
test_dp_test_*.out
5761
*_detail.out
62+
out.json

deepmd/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,12 +752,13 @@ def main_parser() -> argparse.ArgumentParser:
752752
parser_change_bias = subparsers.add_parser(
753753
"change-bias",
754754
parents=[parser_log],
755-
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
755+
help="Change model out bias according to the input data.",
756756
formatter_class=RawTextArgumentDefaultsHelpFormatter,
757757
epilog=textwrap.dedent(
758758
"""\
759759
examples:
760-
dp change-bias model.pt -s data -n 10 -m change
760+
dp --pt change-bias model.pt -s data -n 10 -m change
761+
dp --tf change-bias model.ckpt -s data -n 10 -m change
761762
"""
762763
),
763764
)

deepmd/tf/entrypoints/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from ..infer.model_devi import (
55
make_model_devi,
66
)
7+
from .change_bias import (
8+
change_bias,
9+
)
710
from .compress import (
811
compress,
912
)
@@ -34,6 +37,7 @@
3437
)
3538

3639
__all__ = [
40+
"change_bias",
3741
"compress",
3842
"convert",
3943
"doc_train_input",

0 commit comments

Comments
 (0)