From 10806ccbfe18556bc9abb1d07e21627b6082d256 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Wed, 3 Sep 2025 14:01:07 +0000
Subject: [PATCH 1/3] Initial plan

From efb1cf274b1e32978e314a7bfe79aa54e4f2995f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Wed, 3 Sep 2025 14:30:36 +0000
Subject: [PATCH 2/3] fix(pt): optimize change-bias for multi-task models to save only selected branch

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
---
 deepmd/pt/entrypoints/main.py | 57 +++++++++++++++++++++++++++++------
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 06a7603cc0..92e0f64301 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
-import copy
 import json
 import logging
 import os
@@ -400,7 +399,8 @@ def change_bias(
         old_state_dict = torch.load(
             input_file, map_location=env.DEVICE, weights_only=True
         )
-        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
+        # Only copy model_params, not the entire state dict, to avoid memory bloat
+        model_state_dict = old_state_dict.get("model", old_state_dict)
         model_params = model_state_dict["_extra_state"]["model_params"]
     elif input_file.endswith(".pth"):
         old_model = torch.jit.load(input_file, map_location=env.DEVICE)
@@ -495,14 +495,53 @@ def change_bias(
         output_path = (
             output if output is not None else input_file.replace(".pt", "_updated.pt")
         )
-        wrapper = ModelWrapper(model)
-        if "model" in old_state_dict:
-            old_state_dict["model"] = wrapper.state_dict()
-            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
+        if multi_task:
+            # For multi-task models, save only the selected branch as a single-head model
+            single_head_model = updated_model
+            wrapper = ModelWrapper(single_head_model)
+
+            # Create single-head model parameters
+            single_head_params = model_params["model_dict"][model_branch].copy()
+
+            # Save only the selected branch with single-head structure
+            if "model" in old_state_dict:
+                torch.save(
+                    {
+                        "model": wrapper.state_dict(),
+                        "optimizer": old_state_dict.get("optimizer", {}),
+                    },
+                    output_path,
+                )
+                # Update the saved model's extra state to reflect single-head parameters
+                saved_state = torch.load(output_path, weights_only=True)
+                saved_state["model"]["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(saved_state, output_path)
+            else:
+                state_to_save = wrapper.state_dict()
+                state_to_save["_extra_state"] = {
+                    "model_params": single_head_params,
+                    "train_infos": model_state_dict["_extra_state"].get(
+                        "train_infos", {"lr": 0.001, "step": 0}
+                    ),
+                }
+                torch.save(state_to_save, output_path)
         else:
-            old_state_dict = wrapper.state_dict()
-            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
-        torch.save(old_state_dict, output_path)
+            # For single-task models, keep existing behavior
+            wrapper = ModelWrapper(model)
+            if "model" in old_state_dict:
+                old_state_dict["model"] = wrapper.state_dict()
+                old_state_dict["model"]["_extra_state"] = model_state_dict[
+                    "_extra_state"
+                ]
+            else:
+                old_state_dict = wrapper.state_dict()
+                old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
+            torch.save(old_state_dict, output_path)
     else:
         # for .pth
         output_path = (
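Note on the on-disk layout introduced by PATCH 2/3: for a multi-task .pt checkpoint, change_bias now writes a single-head checkpoint whose "_extra_state" carries only the selected branch's model_params. A minimal inspection sketch, assuming a placeholder output path "model_updated.pt" (the entrypoint actually derives the name from the input path) and a source checkpoint that stored its weights under the "model" key as in the branch above:

    import torch

    # weights_only=True mirrors the load calls used in the patch itself.
    state = torch.load("model_updated.pt", map_location="cpu", weights_only=True)

    extra = state["model"]["_extra_state"]
    # model_params is now a single branch's configuration, not a
    # "model_dict" keyed by branch names.
    assert "model_dict" not in extra["model_params"]
    # Falls back to {"lr": 0.001, "step": 0} when the source checkpoint
    # carried no train_infos.
    print(extra["train_infos"])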
From 5028e761830b074af05b488b9552073a6ae32e37 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Wed, 3 Sep 2025 15:06:04 +0000
Subject: [PATCH 3/3] fix(pt): remove optimizer state from multi-task change-bias to reduce file size

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
---
 deepmd/pt/entrypoints/main.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 92e0f64301..ee4733de74 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -505,22 +505,18 @@ def change_bias(
 
             # Save only the selected branch with single-head structure
             if "model" in old_state_dict:
-                torch.save(
-                    {
-                        "model": wrapper.state_dict(),
-                        "optimizer": old_state_dict.get("optimizer", {}),
-                    },
-                    output_path,
-                )
-                # Update the saved model's extra state to reflect single-head parameters
-                saved_state = torch.load(output_path, weights_only=True)
-                saved_state["model"]["_extra_state"] = {
+                # For multi-task models, don't include optimizer state to reduce file size
+                state_to_save = {
+                    "model": wrapper.state_dict(),
+                }
+                # Update the model's extra state to reflect single-head parameters
+                state_to_save["model"]["_extra_state"] = {
                     "model_params": single_head_params,
                     "train_infos": model_state_dict["_extra_state"].get(
                         "train_infos", {"lr": 0.001, "step": 0}
                     ),
                 }
-                torch.save(saved_state, output_path)
+                torch.save(state_to_save, output_path)
             else:
                 state_to_save = wrapper.state_dict()
                 state_to_save["_extra_state"] = {
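A corresponding sanity check for PATCH 3/3, reusing the placeholder path from the sketch above: after this change the multi-task output should hold only the model entry, with no optimizer state.

    import torch

    state = torch.load("model_updated.pt", map_location="cpu", weights_only=True)
    # Dropping the optimizer state is what shrinks the output file.
    assert "optimizer" not in state
    assert set(state.keys()) == {"model"}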