
Commit b2b9bc4

merge common functions
1 parent fe6cc8c commit b2b9bc4

File tree

7 files changed: +235 -299 lines changed


deepmd/dpmodel/utils/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,11 @@
     save_dp_model,
     traverse_model_dict,
 )
+from .training_utils import (
+    compute_total_numb_batch,
+    resolve_model_prob,
+    resolve_model_prob_from_epochs,
+)
 
 __all__ = [
     "AtomExcludeMask",
@@ -49,6 +54,7 @@
     "aggregate",
     "build_multiple_neighbor_list",
     "build_neighbor_list",
+    "compute_total_numb_batch",
     "extend_coord_with_ghosts",
     "get_graph_index",
     "get_multiple_nlist_key",
@@ -60,6 +66,8 @@
     "nlist_distinguish_types",
     "normalize_coord",
     "phys2inter",
+    "resolve_model_prob",
+    "resolve_model_prob_from_epochs",
     "save_dp_model",
     "to_face_distance",
     "traverse_model_dict",
deepmd/dpmodel/utils/training_utils.py (new file)

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
+from collections.abc import (
+    Iterable,
+)
+
+import numpy as np
+
+log = logging.getLogger(__name__)
+
+
+def compute_total_numb_batch(
+    numb_batches: Iterable[int],
+    sampler_weights: np.ndarray,
+) -> int:
+    """Compute total number of batches considering sampler weights.
+
+    Parameters
+    ----------
+    numb_batches : Iterable[int]
+        Number of batches for each data system.
+    sampler_weights : np.ndarray
+        Sampling weights for each data system.
+
+    Returns
+    -------
+    int
+        Total number of batches.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    weights = np.asarray(sampler_weights, dtype=np.float64)
+    if weights.ndim != 1:
+        raise ValueError("Sampler weights must be 1D.")
+    if weights.size == 0:
+        raise ValueError("Sampler weights are empty.")
+    if not np.all(np.isfinite(weights)):
+        raise ValueError("Sampler weights must be finite.")
+    if np.any(weights < 0.0):
+        raise ValueError("Sampler weights must be non-negative.")
+    weight_sum = float(np.sum(weights))
+    if weight_sum <= 0.0:
+        raise ValueError("Sampler weights must sum to a positive value.")
+    probs = weights / weight_sum
+    nbatches = np.asarray(numb_batches, dtype=np.float64)
+    if nbatches.ndim != 1:
+        raise ValueError("Number of batches must be 1D.")
+    if nbatches.size == 0:
+        raise ValueError("Number of batches is empty.")
+    if not np.all(np.isfinite(nbatches)):
+        raise ValueError("Number of batches must be finite.")
+    if np.any(nbatches < 0.0):
+        raise ValueError("Number of batches must be non-negative.")
+    if nbatches.shape[0] != probs.shape[0]:
+        raise ValueError("Number of batches and sampler weights must match.")
+    valid = probs > 0.0
+    if not np.any(valid):
+        raise ValueError(
+            "Sampler probabilities must contain at least one positive entry."
+        )
+    return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
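The returned total is ceil(max_i(n_i / p_i)): the smallest number of weighted draws at which even the least-sampled system is expected to be covered in full. A quick numeric check with hypothetical numbers, assuming the re-export added in deepmd/dpmodel/utils/__init__.py above:

    import numpy as np
    from deepmd.dpmodel.utils import compute_total_numb_batch

    # Three systems with 100, 50, and 10 batches, sampled with weights
    # 2:1:1, i.e. probabilities 0.5, 0.25, 0.25.
    total = compute_total_numb_batch([100, 50, 10], np.array([2.0, 1.0, 1.0]))
    # max(100/0.5, 50/0.25, 10/0.25) = max(200.0, 200.0, 40.0) -> ceil -> 200
    assert total == 200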
+
+
+def resolve_model_prob(
+    model_keys: list[str],
+    model_prob_config: dict[str, float] | None,
+    model_training_data: dict[str, object],
+    rank: int = 0,
+) -> np.ndarray:
+    """Resolve model training probability for multi-task training.
+
+    Parameters
+    ----------
+    model_keys : list[str]
+        List of model keys.
+    model_prob_config : dict[str, float] | None
+        User-specified model probabilities. If None, use data size.
+    model_training_data : dict[str, object]
+        Training data for each model.
+    rank : int, optional
+        Process rank for distributed training, by default 0.
+
+    Returns
+    -------
+    np.ndarray
+        Normalized model probabilities.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    model_prob = np.zeros(len(model_keys), dtype=np.float64)
+    if model_prob_config:
+        missing = [k for k in model_keys if k not in model_prob_config]
+        if missing:
+            raise ValueError(
+                f"training.model_prob must specify all tasks; missing: {missing}"
+            )
+        for ii, model_key in enumerate(model_keys):
+            if model_key in model_prob_config:
+                model_prob[ii] = float(model_prob_config[model_key])
+    else:
+        if rank == 0:
+            log.info(
+                "training.model_prob is not set or empty; defaulting to the "
+                "number of systems per task."
+            )
+        for ii, model_key in enumerate(model_keys):
+            model_prob[ii] = float(len(model_training_data[model_key]))
+    if not np.all(np.isfinite(model_prob)):
+        raise ValueError("Model prob must be finite.")
+    if np.any(model_prob < 0.0):
+        raise ValueError("Model prob must be non-negative.")
+    sum_prob = float(np.sum(model_prob))
+    if sum_prob <= 0.0:
+        raise ValueError("Sum of model prob must be larger than 0!")
+    return model_prob / sum_prob
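As the docstring notes, an explicit training.model_prob wins; otherwise the probability defaults to the number of systems per task, with a one-time notice logged on rank 0. A minimal sketch with hypothetical task keys and data:

    from deepmd.dpmodel.utils import resolve_model_prob

    # Hypothetical tasks; only len() of each data entry matters for the fallback.
    training_data = {"water": ["sys1", "sys2", "sys3"], "metal": ["sys4"]}

    # Explicit config: weights 3:1 normalize to [0.75, 0.25].
    probs = resolve_model_prob(
        ["water", "metal"], {"water": 3.0, "metal": 1.0}, training_data
    )

    # No config: falls back to system counts, here also [0.75, 0.25].
    probs = resolve_model_prob(["water", "metal"], None, training_data)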
+
+
+def resolve_model_prob_from_epochs(
+    model_keys: list[str],
+    num_epoch_dict_config: dict[str, float],
+    per_task_total: np.ndarray,
+) -> tuple[np.ndarray, int, dict[str, float]]:
+    """Resolve model probability and training steps from epoch configuration.
+
+    Parameters
+    ----------
+    model_keys : list[str]
+        List of model keys.
+    num_epoch_dict_config : dict[str, float]
+        Target epochs for each task.
+    per_task_total : np.ndarray
+        Total batches per task.
+
+    Returns
+    -------
+    tuple[np.ndarray, int, dict[str, float]]
+        Model probabilities, total training steps, and per-task steps.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    if not num_epoch_dict_config:
+        raise ValueError("training.num_epoch_dict must be set for multi-task epochs.")
+    missing = [k for k in model_keys if k not in num_epoch_dict_config]
+    if missing:
+        raise ValueError(
+            f"training.num_epoch_dict must specify all tasks; missing: {missing}"
+        )
+    epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
+    for ii, model_key in enumerate(model_keys):
+        epoch_value = num_epoch_dict_config[model_key]
+        if epoch_value is None:
+            raise ValueError(
+                f"training.num_epoch_dict['{model_key}'] must be positive."
+            )
+        epoch_value = float(epoch_value)
+        if not np.isfinite(epoch_value) or epoch_value <= 0.0:
+            raise ValueError(
+                f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
+            )
+        epoch_targets[ii] = epoch_value
+    per_task_total = np.asarray(per_task_total, dtype=np.float64)
+    if per_task_total.ndim != 1:
+        raise ValueError("Per-task total batches must be 1D.")
+    if per_task_total.shape[0] != epoch_targets.shape[0]:
+        raise ValueError("Per-task totals and epoch targets must match.")
+    if not np.all(np.isfinite(per_task_total)):
+        raise ValueError("Per-task total batches must be finite.")
+    if np.any(per_task_total <= 0.0):
+        raise ValueError("Per-task total batches must be positive.")
+    per_task_steps = per_task_total * epoch_targets
+    total_target_steps = float(np.sum(per_task_steps))
+    if total_target_steps <= 0.0:
+        raise ValueError("Sum of target steps must be positive.")
+    model_prob = per_task_steps / total_target_steps
+    num_steps = int(np.ceil(total_target_steps))
+    per_task_steps_map = {
+        model_key: float(per_task_steps[ii]) for ii, model_key in enumerate(model_keys)
+    }
+    return model_prob, num_steps, per_task_steps_map
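Here each task's step budget is per_task_total * epochs; the probabilities are those budgets normalized, and num_steps is their (ceiled) sum, so every task is visited often enough in expectation to meet its epoch target. A worked example with hypothetical numbers:

    import numpy as np
    from deepmd.dpmodel.utils import resolve_model_prob_from_epochs

    # Task "a": 2 epochs over 100 batches; task "b": 1 epoch over 300 batches.
    model_prob, num_steps, per_task = resolve_model_prob_from_epochs(
        ["a", "b"], {"a": 2, "b": 1}, np.array([100.0, 300.0])
    )
    # per-task steps [200, 300] -> model_prob [0.4, 0.6], num_steps 500
    assert num_steps == 500
    assert per_task == {"a": 200.0, "b": 300.0}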

deepmd/pd/train/training.py

Lines changed: 7 additions & 110 deletions
@@ -30,8 +30,11 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
-from deepmd.dpmodel.utils.learning_rate import (
-    BaseLR,
+from deepmd.dpmodel.utils.learning_rate import BaseLR
+from deepmd.dpmodel.utils import (
+    compute_total_numb_batch,
+    resolve_model_prob,
+    resolve_model_prob_from_epochs,
 )
 from deepmd.loggers.training import (
     format_training_message,
@@ -209,114 +212,6 @@ def get_dataloader_and_buffer(_data, _params):
                 valid_numb_batch,
             )
 
-        def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
-            weights = np.asarray(sampler_weights, dtype=np.float64)
-            if weights.ndim != 1:
-                raise ValueError("Sampler weights must be 1D.")
-            if weights.size == 0:
-                raise ValueError("Sampler weights are empty.")
-            if not np.all(np.isfinite(weights)):
-                raise ValueError("Sampler weights must be finite.")
-            if np.any(weights < 0.0):
-                raise ValueError("Sampler weights must be non-negative.")
-            weight_sum = float(np.sum(weights))
-            if weight_sum <= 0.0:
-                raise ValueError("Sampler weights must sum to a positive value.")
-            probs = weights / weight_sum
-            nbatches = np.asarray(numb_batches, dtype=np.float64)
-            if nbatches.ndim != 1:
-                raise ValueError("Number of batches must be 1D.")
-            if nbatches.size == 0:
-                raise ValueError("Number of batches is empty.")
-            if not np.all(np.isfinite(nbatches)):
-                raise ValueError("Number of batches must be finite.")
-            if np.any(nbatches < 0.0):
-                raise ValueError("Number of batches must be non-negative.")
-            if nbatches.shape[0] != probs.shape[0]:
-                raise ValueError("Number of batches and sampler weights must match.")
-            valid = probs > 0.0
-            if not np.any(valid):
-                raise ValueError(
-                    "Sampler probabilities must contain at least one positive entry."
-                )
-            return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
-
-        def resolve_model_prob(
-            model_keys,
-            model_prob_config,
-            model_training_data,
-        ) -> np.ndarray:
-            model_prob = np.zeros(len(model_keys), dtype=np.float64)
-            if model_prob_config:
-                missing = [k for k in model_keys if k not in model_prob_config]
-                if missing:
-                    raise ValueError(
-                        f"training.model_prob must specify all tasks; missing: {missing}"
-                    )
-                for ii, model_key in enumerate(model_keys):
-                    if model_key in model_prob_config:
-                        model_prob[ii] = float(model_prob_config[model_key])
-            else:
-                for ii, model_key in enumerate(model_keys):
-                    model_prob[ii] = float(len(model_training_data[model_key]))
-            if not np.all(np.isfinite(model_prob)):
-                raise ValueError("Model prob must be finite.")
-            if np.any(model_prob < 0.0):
-                raise ValueError("Model prob must be non-negative.")
-            sum_prob = float(np.sum(model_prob))
-            if sum_prob <= 0.0:
-                raise ValueError("Sum of model prob must be larger than 0!")
-            return model_prob / sum_prob
-
-        def resolve_model_prob_from_epochs(
-            model_keys,
-            num_epoch_dict_config,
-            per_task_total,
-        ) -> tuple[np.ndarray, int, dict[str, float]]:
-            if not num_epoch_dict_config:
-                raise ValueError(
-                    "training.num_epoch_dict must be set for multi-task epochs."
-                )
-            missing = [k for k in model_keys if k not in num_epoch_dict_config]
-            if missing:
-                raise ValueError(
-                    "training.num_epoch_dict must specify all tasks; "
-                    f"missing: {missing}"
-                )
-            epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
-            for ii, model_key in enumerate(model_keys):
-                epoch_value = num_epoch_dict_config[model_key]
-                if epoch_value is None:
-                    raise ValueError(
-                        f"training.num_epoch_dict['{model_key}'] must be positive."
-                    )
-                epoch_value = float(epoch_value)
-                if not np.isfinite(epoch_value) or epoch_value <= 0.0:
-                    raise ValueError(
-                        f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
-                    )
-                epoch_targets[ii] = epoch_value
-            per_task_total = np.asarray(per_task_total, dtype=np.float64)
-            if per_task_total.ndim != 1:
-                raise ValueError("Per-task total batches must be 1D.")
-            if per_task_total.shape[0] != epoch_targets.shape[0]:
-                raise ValueError("Per-task totals and epoch targets must match.")
-            if not np.all(np.isfinite(per_task_total)):
-                raise ValueError("Per-task total batches must be finite.")
-            if np.any(per_task_total <= 0.0):
-                raise ValueError("Per-task total batches must be positive.")
-            per_task_steps = per_task_total * epoch_targets
-            total_target_steps = float(np.sum(per_task_steps))
-            if total_target_steps <= 0.0:
-                raise ValueError("Sum of target steps must be positive.")
-            model_prob = per_task_steps / total_target_steps
-            num_steps = int(np.ceil(total_target_steps))
-            per_task_steps_map = {
-                model_key: float(per_task_steps[ii])
-                for ii, model_key in enumerate(model_keys)
-            }
-            return model_prob, num_steps, per_task_steps_map
-
         def single_model_stat(
             _model,
             _data_stat_nbatch,
@@ -563,6 +458,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.model_keys,
             training_params.get("model_prob"),
             training_data,
+            rank=self.rank,
         )
 
         # Learning rate
@@ -756,6 +652,7 @@ def single_model_finetune(
             self.model_keys,
             training_params.get("model_prob"),
             training_data,
+            rank=self.rank,
        )
 
         # Multi-task share params
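With the helpers consolidated into deepmd.dpmodel.utils, both call sites above reduce to the shared import plus the new rank keyword. A minimal self-contained sketch; the variable values are hypothetical stand-ins for the trainer's state:

    from deepmd.dpmodel.utils import (
        compute_total_numb_batch,
        resolve_model_prob,
        resolve_model_prob_from_epochs,
    )

    # Hypothetical stand-ins for the trainer's state:
    model_keys = ["water", "metal"]
    training_params = {"model_prob": {"water": 1.0, "metal": 1.0}}
    training_data = {"water": ["sys1"], "metal": ["sys2"]}

    model_prob = resolve_model_prob(
        model_keys, training_params.get("model_prob"), training_data, rank=0
    )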
