
Commit eeabe19

root committed: run auto-parallel
1 parent 35783bb, commit eeabe19

12 files changed, +149 −138 lines changed


deepmd/pd/loss/ener.py

Lines changed: 9 additions & 2 deletions

@@ -21,6 +21,7 @@
 from deepmd.utils.version import (
     check_version_compatibility,
 )
+import paddle.distributed as dist
 
 
 def custom_huber_loss(predictions, targets, delta=1.0):

@@ -205,7 +206,11 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
             find_energy = label.get("find_energy", 0.0)
             pref_e = pref_e * find_energy
             if not self.use_l1_all:
-                l2_ener_loss = paddle.mean(paddle.square(energy_pred - energy_label))
+
+                tmp = energy_pred - energy_label
+                logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()])
+
+                l2_ener_loss = paddle.mean(paddle.square(logit))
                 if not self.inference:
                     more_loss["l2_ener_loss"] = self.display_if_exist(
                         l2_ener_loss.detach(), find_energy

@@ -258,7 +263,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
             force_pred = model_pred["force"]
             force_label = label["force"]
             diff_f = (force_label - force_pred).reshape([-1])
-
+            diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
+
             if self.relative_f is not None:
                 force_label_3 = force_label.reshape([-1, 3])
                 norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f

@@ -354,6 +360,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
             find_virial = label.get("find_virial", 0.0)
             pref_v = pref_v * find_virial
             diff_v = label["virial"] - model_pred["virial"].reshape([-1, 9])
+            diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
             l2_virial_loss = paddle.mean(paddle.square(diff_v))
             if not self.inference:
                 more_loss["l2_virial_loss"] = self.display_if_exist(
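All three hunks apply the same pattern: once the batch dimension is sharded across the auto-parallel mesh, the element-wise difference is resharded to a replicated placement before paddle.mean, so every rank reduces over the full batch rather than only its local shard. The sketch below illustrates that pattern; the mesh, shapes, and launch command are illustrative assumptions rather than part of this commit, and it needs a multi-card paddle.distributed launch to run.

# Minimal sketch of the reshard-before-reduce pattern (illustrative, not from this
# commit). Launch with e.g.: python -m paddle.distributed.launch --devices=0,1 demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])  # assumed 2-card mesh
pred = dist.shard_tensor(paddle.randn([8]), mesh, [dist.Shard(0)])   # batch-sharded
label = dist.shard_tensor(paddle.randn([8]), mesh, [dist.Shard(0)])

diff = pred - label  # still sharded along dim 0
diff = dist.reshard(diff, diff.process_mesh, [dist.Replicate()])  # replicate across the mesh
loss = paddle.mean(paddle.square(diff))  # the same scalar on every rank
print(loss)

Reading the mesh from diff.process_mesh, as the commit does, avoids hard-coding it, since a distributed tensor already carries its own mesh.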

deepmd/pd/train/training.py

Lines changed: 117 additions & 113 deletions

@@ -164,17 +164,14 @@ def get_opt_param(params):
 
 def get_data_loader(_training_data, _validation_data, _training_params):
     def get_dataloader_and_buffer(_data, _params):
-        _sampler = get_sampler_from_params(_data, _params)
-        if _sampler is None:
-            log.warning(
-                "Sampler not specified!"
-            )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
+        # _sampler = get_sampler_from_params(_data, _params)
+        # if _sampler is None:
+        #     log.warning(
+        #         "Sampler not specified!"
+        #     )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
         _dataloader = DataLoader(
             _data,
-            batch_sampler=paddle.io.BatchSampler(
-                sampler=_sampler,
-                drop_last=False,
-            ),
+            batch_size=1,
             num_workers=NUM_WORKERS
             if dist.is_available()
             else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1

@@ -325,17 +322,18 @@ def get_lr(lr_params):
                 self.validation_data,
                 self.valid_numb_batch,
             ) = get_data_loader(training_data, validation_data, training_params)
-            training_data.print_summary(
-                "training",
-                to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
-            )
-            if validation_data is not None:
-                validation_data.print_summary(
-                    "validation",
-                    to_numpy_array(
-                        self.validation_dataloader.batch_sampler.sampler.weights
-                    ),
-                )
+            # no sampler, do not need print!
+            # training_data.print_summary(
+            #     "training",
+            #     to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
+            # )
+            # if validation_data is not None:
+            #     validation_data.print_summary(
+            #         "validation",
+            #         to_numpy_array(
+            #             self.validation_dataloader.batch_sampler.sampler.weights
+            #         ),
+            #     )
         else:
             (
                 self.training_dataloader,

@@ -370,27 +368,27 @@ def get_lr(lr_params):
                     validation_data[model_key],
                     training_params["data_dict"][model_key],
                 )
-
-                training_data[model_key].print_summary(
-                    f"training in {model_key}",
-                    to_numpy_array(
-                        self.training_dataloader[
-                            model_key
-                        ].batch_sampler.sampler.weights
-                    ),
-                )
-                if (
-                    validation_data is not None
-                    and validation_data[model_key] is not None
-                ):
-                    validation_data[model_key].print_summary(
-                        f"validation in {model_key}",
-                        to_numpy_array(
-                            self.validation_dataloader[
-                                model_key
-                            ].batch_sampler.sampler.weights
-                        ),
-                    )
+                # no sampler, do not need print!
+                # training_data[model_key].print_summary(
+                #     f"training in {model_key}",
+                #     to_numpy_array(
+                #         self.training_dataloader[
+                #             model_key
+                #         ].batch_sampler.sampler.weights
+                #     ),
+                # )
+                # if (
+                #     validation_data is not None
+                #     and validation_data[model_key] is not None
+                # ):
+                #     validation_data[model_key].print_summary(
+                #         f"validation in {model_key}",
+                #         to_numpy_array(
+                #             self.validation_dataloader[
+                #                 model_key
+                #             ].batch_sampler.sampler.weights
+                #         ),
+                #     )
 
         # Learning rate
         self.warmup_steps = training_params.get("warmup_steps", 0)

@@ -706,7 +704,7 @@ def run(self) -> None:
             fout1 = open(record_file, mode="w", buffering=1)
         log.info("Start to train %d steps.", self.num_steps)
         if dist.is_available() and dist.is_initialized():
-            log.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
+            log.info(f"xxx Rank: {dist.get_rank()}/{dist.get_world_size()}")
         if self.enable_tensorboard:
             from tensorboardX import (
                 SummaryWriter,

@@ -755,50 +753,54 @@ def step(_step_id, task_key="Default") -> None:
                     if self.world_size > 1
                     else contextlib.nullcontext
                 )
+
+                # with nvprof_context(enable_profiling, "Forward pass"):
+                log_dict = {}
+
+                input_dict = {
+                    "spin": None,
+                    "fparam": None,
+                    "aparam": None,
+                }
+                label_dict = {
+                    "find_box": 1.0,
+                    "find_coord": 1.0,
+                    "find_numb_copy": 0.0,
+                    "find_energy": 1.0,
+                    "find_force": 1.0,
+                    "find_virial": 0.0,
+                }
+                for k in ["atype", "box", "coord"]:
+                    input_dict[k] = paddle.load(f"./input_{k}.pd")
+                for k in ["energy", "force", "natoms", "numb_copy", "virial"]:
+                    label_dict[k] = paddle.load(f"./label_{k}.pd")
+
+                for __key in ('coord', 'atype', 'box'):
+                    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                for __key, _ in label_dict.items():
+                    if isinstance(label_dict[__key], paddle.Tensor):
+                        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
 
-                # with sync_context():
-                #     with nvprof_context(enable_profiling, "Forward pass"):
-                #         model_pred, loss, more_loss = self.wrapper(
-                #             **input_dict,
-                #             cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                #             label=label_dict,
-                #             task_key=task_key,
-                #         )
-
-                #     with nvprof_context(enable_profiling, "Backward pass"):
-                #         loss.backward()
-
-                #     if self.world_size > 1:
-                #         # fuse + allreduce manually before optimization if use DDP + no_sync
-                #         # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
-                #         hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
-
-                with nvprof_context(enable_profiling, "Forward pass"):
-                    for __key in ('coord', 'atype', 'box'):
-                        input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
-                    for __key, _ in label_dict.items():
-                        if isinstance(label_dict[__key], paddle.Tensor):
-                            label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
-                    model_pred, loss, more_loss = self.wrapper(
-                        **input_dict,
-                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                        label=label_dict,
-                        task_key=task_key,
-                    )
+                model_pred, loss, more_loss = self.wrapper(
+                    **input_dict,
+                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                    label=label_dict,
+                    task_key=task_key,
+                )
 
-                with nvprof_context(enable_profiling, "Backward pass"):
-                    loss.backward()
+                # with nvprof_context(enable_profiling, "Backward pass"):
+                loss.backward()
 
                 if self.gradient_max_norm > 0.0:
-                    with nvprof_context(enable_profiling, "Gradient clip"):
-                        paddle.nn.utils.clip_grad_norm_(
-                            self.wrapper.parameters(),
-                            self.gradient_max_norm,
-                            error_if_nonfinite=True,
-                        )
+                    # with nvprof_context(enable_profiling, "Gradient clip"):
+                    paddle.nn.utils.clip_grad_norm_(
+                        self.wrapper.parameters(),
+                        self.gradient_max_norm,
+                        error_if_nonfinite=True,
+                    )
 
-                with nvprof_context(enable_profiling, "Adam update"):
-                    self.optimizer.step()
+                # with nvprof_context(enable_profiling, "Adam update"):
+                self.optimizer.step()
                 self.scheduler.step()
 
             else:

@@ -856,7 +858,9 @@ def log_loss_valid(_task_key="Default"):
 
             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
-                valid_results = log_loss_valid()
+                # valid_results = log_loss_valid()
+                # no run valid!
+                valid_results = None
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(

@@ -938,39 +942,39 @@ def log_loss_valid(_task_key="Default"):
             ):
                 self.total_train_time += train_time
 
-            if fout:
-                if self.lcurve_should_print_header:
-                    self.print_header(fout, train_results, valid_results)
-                    self.lcurve_should_print_header = False
-                self.print_on_training(
-                    fout, display_step_id, cur_lr, train_results, valid_results
-                )
-
-            if (
-                ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
-                or (_step_id + 1) == self.num_steps
-            ) and (self.rank == 0 or dist.get_rank() == 0):
-                # Handle the case if rank 0 aborted and re-assigned
-                self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pd")
-                self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                log.info(f"Saved model to {self.latest_model}")
-                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-                with open("checkpoint", "w") as f:
-                    f.write(str(self.latest_model))
+            # if fout:
+            #     if self.lcurve_should_print_header:
+            #         self.print_header(fout, train_results, valid_results)
+            #         self.lcurve_should_print_header = False
+            #     self.print_on_training(
+            #         fout, display_step_id, cur_lr, train_results, valid_results
+            #     )
+
+            # if (
+            #     ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
+            #     or (_step_id + 1) == self.num_steps
+            # ) and (self.rank == 0 or dist.get_rank() == 0):
+            #     # Handle the case if rank 0 aborted and re-assigned
+            #     self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pd")
+            #     self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
+            #     log.info(f"Saved model to {self.latest_model}")
+            #     symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+            #     with open("checkpoint", "w") as f:
+            #         f.write(str(self.latest_model))
 
             # tensorboard
-            if self.enable_tensorboard and (
-                display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
-            ):
-                writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
-                writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id)
-                for item in more_loss:
-                    writer.add_scalar(
-                        f"{task_key}/{item}", more_loss[item].item(), display_step_id
-                    )
-
-            if enable_profiling:
-                core.nvprof_nvtx_pop()
+            # if self.enable_tensorboard and (
+            #     display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
+            # ):
+            #     writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
+            #     writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id)
+            #     for item in more_loss:
+            #         writer.add_scalar(
+            #             f"{task_key}/{item}", more_loss[item].item(), display_step_id
+            #         )
+
+            # if enable_profiling:
+            #     core.nvprof_nvtx_pop()
 
         self.wrapper.train()
         self.t0 = time.time()
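The rewritten step() above replaces the sampled batch with tensors loaded from ./input_*.pd and ./label_*.pd and shards every tensor along dim 0 with dist.shard_tensor on the mesh returned by dist.get_mesh(), while validation, checkpointing, and tensorboard logging are commented out. A small sketch of that batch-sharding step follows; the shard_batch helper name and the explicit two-card mesh are assumptions made for illustration (the commit itself relies on a globally registered mesh via dist.get_mesh()).

# Hypothetical helper mirroring the batch-sharding loop in step(); the function
# name and the explicit mesh are illustrative assumptions, not part of the commit.
import paddle
import paddle.distributed as dist

def shard_batch(batch: dict, mesh: dist.ProcessMesh) -> dict:
    # Split every tensor in the batch along dim 0 across the data-parallel mesh;
    # non-tensor entries (None placeholders, flag scalars) are left untouched.
    for key, value in batch.items():
        if isinstance(value, paddle.Tensor):
            batch[key] = dist.shard_tensor(value, mesh=mesh, placements=[dist.Shard(0)])
    return batch

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])  # assumed 2-card mesh
# usage: input_dict = shard_batch(input_dict, mesh); label_dict = shard_batch(label_dict, mesh)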

deepmd/pd/utils/dataloader.py

Lines changed: 21 additions & 21 deletions

@@ -168,30 +168,30 @@ def construct_dataset(system):
             self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
         assert len(self.systems) == len(self.batch_sizes)
         for system, batch_size in zip(self.systems, self.batch_sizes):
-            if dist.is_available() and dist.is_initialized():
-                system_batch_sampler = DistributedBatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
-            else:
-                system_batch_sampler = BatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
+            # if dist.is_available() and dist.is_initialized():
+            #     system_batch_sampler = DistributedBatchSampler(
+            #         system,
+            #         shuffle=(
+            #             (not (dist.is_available() and dist.is_initialized()))
+            #             and shuffle
+            #         ),
+            #         batch_size=int(batch_size),
+            #     )
+            #     self.sampler_list.append(system_batch_sampler)
+            # else:
+            #     system_batch_sampler = BatchSampler(
+            #         system,
+            #         shuffle=(
+            #             (not (dist.is_available() and dist.is_initialized()))
+            #             and shuffle
+            #         ),
+            #         batch_size=int(batch_size),
+            #     )
+            #     self.sampler_list.append(system_batch_sampler)
             system_dataloader = DataLoader(
                 dataset=system,
                 num_workers=0,  # Should be 0 to avoid too many threads forked
-                batch_sampler=system_batch_sampler,
+                batch_size=int(batch_size),
                 collate_fn=collate_batch,
                 use_buffer_reader=False,
                 places=["cpu"],
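With the sampler construction commented out, each per-system DataLoader is driven by batch_size alone, so Paddle builds its default sequential batch sampler internally and neither shuffling nor per-rank splitting happens at the loader level any more. A self-contained toy sketch of this batch_size-only construction, with a dataset class invented purely for illustration:

# Toy illustration of a batch_size-driven paddle.io.DataLoader (the dataset is
# made up for this example and is not part of the commit).
import paddle
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, n: int = 8):
        self.data = paddle.arange(n, dtype="float32").reshape([n, 1])

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# batch_size alone -> default sequential BatchSampler; no shuffle, no DistributedBatchSampler
loader = DataLoader(ToyDataset(), batch_size=2, num_workers=0)
for batch in loader:
    tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
    print(tensor.shape)  # [2, 1]; batches arrive in dataset order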

examples/water/dpa3/input_atype.pd: 24.2 KB, binary file not shown.

examples/water/dpa3/input_box.pd: 2.44 KB, binary file not shown.

examples/water/dpa3/input_coord.pd: 144 KB, binary file not shown.

examples/water/dpa3/input_torch.json

Lines changed: 2 additions & 2 deletions

@@ -75,14 +75,14 @@
             "../data/data_1",
             "../data/data_2"
         ],
-        "batch_size": 1,
+        "batch_size": 32,
         "_comment": "that's all"
     },
     "validation_data": {
         "systems": [
             "../data/data_3"
         ],
-        "batch_size": 1,
+        "batch_size": 32,
         "_comment": "that's all"
     },
     "numb_steps": 2000,
453 Bytes, binary file not shown.

examples/water/dpa3/label_force.pd: 144 KB, binary file not shown.

709 Bytes, binary file not shown.
