Skip to content

Commit 95c1377

Browse files
author
xuexixi
committed
improve auto parallel perf
1 parent 35783bb commit 95c1377

File tree

4 files changed

+49
-42
lines changed

4 files changed

+49
-42
lines changed

deepmd/pd/loss/ener.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from deepmd.utils.version import (
2222
check_version_compatibility,
2323
)
24+
import paddle.distributed as dist
2425

2526

2627
def custom_huber_loss(predictions, targets, delta=1.0):
@@ -205,7 +206,11 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
205206
find_energy = label.get("find_energy", 0.0)
206207
pref_e = pref_e * find_energy
207208
if not self.use_l1_all:
208-
l2_ener_loss = paddle.mean(paddle.square(energy_pred - energy_label))
209+
210+
tmp = energy_pred - energy_label
211+
logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()])
212+
213+
l2_ener_loss = paddle.mean(paddle.square(logit))
209214
if not self.inference:
210215
more_loss["l2_ener_loss"] = self.display_if_exist(
211216
l2_ener_loss.detach(), find_energy
@@ -258,7 +263,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
258263
force_pred = model_pred["force"]
259264
force_label = label["force"]
260265
diff_f = (force_label - force_pred).reshape([-1])
261-
266+
diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
267+
262268
if self.relative_f is not None:
263269
force_label_3 = force_label.reshape([-1, 3])
264270
norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
@@ -354,6 +360,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
354360
find_virial = label.get("find_virial", 0.0)
355361
pref_v = pref_v * find_virial
356362
diff_v = label["virial"] - model_pred["virial"].reshape([-1, 9])
363+
diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
357364
l2_virial_loss = paddle.mean(paddle.square(diff_v))
358365
if not self.inference:
359366
more_loss["l2_virial_loss"] = self.display_if_exist(

deepmd/pd/train/training.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,7 @@ def get_dataloader_and_buffer(_data, _params):
171171
) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
172172
_dataloader = DataLoader(
173173
_data,
174-
batch_sampler=paddle.io.BatchSampler(
175-
sampler=_sampler,
176-
drop_last=False,
177-
),
174+
batch_size=1,
178175
num_workers=NUM_WORKERS
179176
if dist.is_available()
180177
else 0, # setting to 0 diverges the behavior of its iterator; should be >=1
@@ -325,17 +322,18 @@ def get_lr(lr_params):
325322
self.validation_data,
326323
self.valid_numb_batch,
327324
) = get_data_loader(training_data, validation_data, training_params)
328-
training_data.print_summary(
329-
"training",
330-
to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
331-
)
332-
if validation_data is not None:
333-
validation_data.print_summary(
334-
"validation",
335-
to_numpy_array(
336-
self.validation_dataloader.batch_sampler.sampler.weights
337-
),
338-
)
325+
# no sampler, do not need print!
326+
# training_data.print_summary(
327+
# "training",
328+
# to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
329+
# )
330+
# if validation_data is not None:
331+
# validation_data.print_summary(
332+
# "validation",
333+
# to_numpy_array(
334+
# self.validation_dataloader.batch_sampler.sampler.weights
335+
# ),
336+
# )
339337
else:
340338
(
341339
self.training_dataloader,
@@ -370,27 +368,27 @@ def get_lr(lr_params):
370368
validation_data[model_key],
371369
training_params["data_dict"][model_key],
372370
)
373-
374-
training_data[model_key].print_summary(
375-
f"training in {model_key}",
376-
to_numpy_array(
377-
self.training_dataloader[
378-
model_key
379-
].batch_sampler.sampler.weights
380-
),
381-
)
382-
if (
383-
validation_data is not None
384-
and validation_data[model_key] is not None
385-
):
386-
validation_data[model_key].print_summary(
387-
f"validation in {model_key}",
388-
to_numpy_array(
389-
self.validation_dataloader[
390-
model_key
391-
].batch_sampler.sampler.weights
392-
),
393-
)
371+
# no sampler, do not need print!
372+
# training_data[model_key].print_summary(
373+
# f"training in {model_key}",
374+
# to_numpy_array(
375+
# self.training_dataloader[
376+
# model_key
377+
# ].batch_sampler.sampler.weights
378+
# ),
379+
# )
380+
# if (
381+
# validation_data is not None
382+
# and validation_data[model_key] is not None
383+
# ):
384+
# validation_data[model_key].print_summary(
385+
# f"validation in {model_key}",
386+
# to_numpy_array(
387+
# self.validation_dataloader[
388+
# model_key
389+
# ].batch_sampler.sampler.weights
390+
# ),
391+
# )
394392

395393
# Learning rate
396394
self.warmup_steps = training_params.get("warmup_steps", 0)
@@ -856,7 +854,9 @@ def log_loss_valid(_task_key="Default"):
856854

857855
if not self.multi_task:
858856
train_results = log_loss_train(loss, more_loss)
859-
valid_results = log_loss_valid()
857+
# valid_results = log_loss_valid()
858+
# no run valid!
859+
valid_results = None
860860
if self.rank == 0:
861861
log.info(
862862
format_training_message_per_task(

deepmd/pd/utils/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def construct_dataset(system):
191191
system_dataloader = DataLoader(
192192
dataset=system,
193193
num_workers=0, # Should be 0 to avoid too many threads forked
194-
batch_sampler=system_batch_sampler,
194+
batch_size=int(batch_size),
195195
collate_fn=collate_batch,
196196
use_buffer_reader=False,
197197
places=["cpu"],

examples/water/dpa3/input_torch.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@
7575
"../data/data_1",
7676
"../data/data_2"
7777
],
78-
"batch_size": 1,
78+
"batch_size": 32,
7979
"_comment": "that's all"
8080
},
8181
"validation_data": {
8282
"systems": [
8383
"../data/data_3"
8484
],
85-
"batch_size": 1,
85+
"batch_size": 32,
8686
"_comment": "that's all"
8787
},
8888
"numb_steps": 2000,

0 commit comments

Comments
 (0)