
Commit c98185c

feat(pt): support multitask argcheck (#3925)
Note that:

1. Docs for the multi-task args are not supported yet and may need help.
2. `trim_pattern="_*"` is not supported for repeat dict Arguments; dargs may need an update.

## Summary by CodeRabbit

- **New Features**
  - Enhanced the training configuration to support multi-task mode, with additional arguments for data configuration.
  - Updated the example configurations to reflect the multi-task mode changes.
- **Bug Fixes**
  - The configuration is now updated and normalized during training regardless of multi-task mode.
- **Dependencies**
  - Upgraded the `dargs` package requirement to version `>= 0.4.7`.
- **Tests**
  - Added new test cases for multi-task scenarios in the `TestExamples` class.
1 parent e809e64 commit c98185c

5 files changed: +83 −22 lines changed


deepmd/pt/entrypoints/main.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -245,9 +245,8 @@ def train(FLAGS):
     )
 
     # argcheck
-    if not multi_task:
-        config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
-        config = normalize(config)
+    config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
+    config = normalize(config, multi_task=multi_task)
 
     # do neighbor stat
     min_nbor_dist = None
```
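
With the conditional gone, every training input now passes through the same compatibility update and argument check; the `multi_task` flag only selects which schema `normalize` applies. A minimal sketch of the resulting call path, assuming a config loaded from a JSON input file (the file name and the `model_dict`-based detection below are assumptions, not taken from this diff):

```python
import json

from deepmd.utils.argcheck import normalize

# hypothetical input file name
with open("input_torch.json") as f:
    config = json.load(f)

# assumption: multi-task inputs are recognized by a "model_dict" section
multi_task = "model_dict" in config["model"]
config = normalize(config, multi_task=multi_task)  # raises if the input violates the schema
```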

deepmd/utils/argcheck.py

Lines changed: 67 additions & 15 deletions
```diff
@@ -2325,7 +2325,9 @@ def mixed_precision_args():  # ! added by Denghui.
     )
 
 
-def training_args():  # ! modified by Ziyao: data configuration isolated.
+def training_args(
+    multi_task=False,
+):  # ! modified by Ziyao: data configuration isolated.
     doc_numb_steps = "Number of training batch. Each training uses one batch of data."
     doc_seed = "The random seed for getting frames from the training data set."
     doc_disp_file = "The file for printing learning curve."
@@ -2364,14 +2366,30 @@ def training_args():  # ! modified by Ziyao: data configuration isolated.
     )
     doc_opt_type = "The type of optimizer to use."
     doc_kf_blocksize = "The blocksize for the Kalman filter."
+    doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode."
+    doc_data_dict = "The multiple definition of the data, used in the multi-task mode."
 
     arg_training_data = training_data_args()
     arg_validation_data = validation_data_args()
     mixed_precision_data = mixed_precision_args()
 
-    args = [
+    data_args = [
         arg_training_data,
         arg_validation_data,
+        Argument(
+            "stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
+        ),
+    ]
+    args = (
+        data_args
+        if not multi_task
+        else [
+            Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob),
+            Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict),
+        ]
+    )
+
+    args += [
         mixed_precision_data,
         Argument(
             "numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"]
@@ -2438,9 +2456,6 @@ def training_args():  # ! modified by Ziyao: data configuration isolated.
             optional=True,
             doc=doc_only_pt_supported + doc_gradient_max_norm,
         ),
-        Argument(
-            "stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
-        ),
     ]
     variants = [
         Variant(
@@ -2472,6 +2487,34 @@ def training_args():  # ! modified by Ziyao: data configuration isolated.
     return Argument("training", dict, args, variants, doc=doc_training)
 
 
+def multi_model_args():
+    model_dict = model_args()
+    model_dict.name = "model_dict"
+    model_dict.repeat = True
+    model_dict.doc = (
+        "The multiple definition of the model, used in the multi-task mode."
+    )
+    doc_shared_dict = "The definition of the shared parameters used in the `model_dict` within multi-task mode."
+    return Argument(
+        "model",
+        dict,
+        [
+            model_dict,
+            Argument(
+                "shared_dict", dict, optional=True, default={}, doc=doc_shared_dict
+            ),
+        ],
+    )
+
+
+def multi_loss_args():
+    loss_dict = loss_args()
+    loss_dict.name = "loss_dict"
+    loss_dict.repeat = True
+    loss_dict.doc = "The multiple definition of the loss, used in the multi-task mode."
+    return loss_dict
+
+
 def make_index(keys):
     ret = []
     for ii in keys:
@@ -2502,14 +2545,23 @@ def gen_json(**kwargs):
     )
 
 
-def gen_args(**kwargs) -> List[Argument]:
-    return [
-        model_args(),
-        learning_rate_args(),
-        loss_args(),
-        training_args(),
-        nvnmd_args(),
-    ]
+def gen_args(multi_task=False) -> List[Argument]:
+    if not multi_task:
+        return [
+            model_args(),
+            learning_rate_args(),
+            loss_args(),
+            training_args(multi_task=multi_task),
+            nvnmd_args(),
+        ]
+    else:
+        return [
+            multi_model_args(),
+            learning_rate_args(),
+            multi_loss_args(),
+            training_args(multi_task=multi_task),
+            nvnmd_args(),
+        ]
 
 
 def gen_json_schema() -> str:
@@ -2524,8 +2576,8 @@ def gen_json_schema() -> str:
     return json.dumps(generate_json_schema(arg))
 
 
-def normalize(data):
-    base = Argument("base", dict, gen_args())
+def normalize(data, multi_task=False):
+    base = Argument("base", dict, gen_args(multi_task=multi_task))
     data = base.normalize_value(data, trim_pattern="_*")
     base.check_value(data, strict=True)
```
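The multi-task schema relies on dargs' repeated sub-arguments: with `repeat=True`, every key under an `Argument` such as `model_dict`, `loss_dict`, or `data_dict` is validated against the same sub-schema. A standalone sketch of that mechanism, using a hypothetical field `foo`; note that, per the commit message, `trim_pattern="_*"` is not applied inside repeat dict Arguments, so `_comment` keys cannot be left in those sections:

```python
from dargs import Argument

# every entry under "model_dict" is checked against the same sub-schema
sub_fields = [Argument("foo", int, optional=True, default=1)]
base = Argument("base", dict, [Argument("model_dict", dict, sub_fields, repeat=True)])

data = {"model_dict": {"water_1": {"foo": 2}, "water_2": {}}}
data = base.normalize_value(data)  # fills defaults in each repeated entry
base.check_value(data, strict=True)  # rejects unknown keys
print(data)  # {'model_dict': {'water_1': {'foo': 2}, 'water_2': {'foo': 1}}}
```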

examples/water_multi_task/pytorch_example/input_torch.json

Lines changed: 0 additions & 1 deletion
```diff
@@ -67,7 +67,6 @@
         "_comment": "that's all"
     },
     "loss_dict": {
-        "_comment": " that's all",
         "water_1": {
             "type": "ener",
             "start_pref_e": 0.02,
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -39,7 +39,7 @@ dependencies = [
     'numpy',
     'scipy',
     'pyyaml',
-    'dargs >= 0.4.6',
+    'dargs >= 0.4.7',
     'typing_extensions; python_version < "3.8"',
     'importlib_metadata>=1.4; python_version < "3.8"',
     'h5py',
```

source/tests/common/test_examples.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -15,6 +15,10 @@
     normalize,
 )
 
+from ..pt.test_multitask import (
+    preprocess_shared_params,
+)
+
 p_examples = Path(__file__).parent.parent.parent.parent / "examples"
 
 input_files = (
@@ -51,11 +55,18 @@
     p_examples / "water" / "dpa2" / "input_torch.json",
 )
 
+input_files_multi = (
+    p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json",
+)
+
 
 class TestExamples(unittest.TestCase):
     def test_arguments(self):
-        for fn in input_files:
+        for fn in input_files + input_files_multi:
+            multi_task = fn in input_files_multi
             fn = str(fn)
             with self.subTest(fn=fn):
                 jdata = j_loader(fn)
-                normalize(jdata)
+                if multi_task:
+                    jdata["model"], _ = preprocess_shared_params(jdata["model"])
+                normalize(jdata, multi_task=multi_task)
```
