Skip to content

Commit 6309ccd

Browse files
committed
reset
1 parent 82d9b65 commit 6309ccd

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,12 +757,14 @@ def forward(
757757
nb, nloc, nnei = nlist.shape
758758
nall = node_ebd_ext.shape[1]
759759
node_ebd = node_ebd_ext[:, :nloc, :]
760-
n_edge = int(nlist_mask.sum().item()) if self.use_dynamic_sel else 0
761760
assert (nb, nloc) == node_ebd.shape[:2]
762761
if not self.use_dynamic_sel:
763762
assert (nb, nloc, nnei, 3) == h2.shape
763+
n_edge = None
764764
else:
765-
assert (n_edge, 3) == h2.shape
765+
# n_edge = int(nlist_mask.sum().item())
766+
# assert (n_edge, 3) == h2.shape
767+
n_edge = h2.shape[0]
766768
del a_nlist # may be used in the future
767769

768770
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]

deepmd/pt/model/descriptor/repformers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def forward(
409409
]:
410410
if comm_dict is None:
411411
assert mapping is not None
412+
assert extended_atype_embd is not None
412413
nframes, nloc, nnei = nlist.shape
413414
nall = extended_coord.view(nframes, -1).shape[1] // 3
414415
atype = extended_atype[:, :nloc]
@@ -432,9 +433,13 @@ def forward(
432433
sw = sw.masked_fill(~nlist_mask, 0.0)
433434

434435
# [nframes, nloc, tebd_dim]
435-
assert extended_atype_embd is not None
436-
atype_embd = extended_atype_embd[:, :nloc, :]
437-
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
436+
if comm_dict is None:
437+
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
438+
atype_embd = extended_atype_embd[:, :nloc, :]
439+
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
440+
else:
441+
atype_embd = extended_atype_embd
442+
assert isinstance(atype_embd, torch.Tensor) # for jit
438443
g1 = self.act(atype_embd)
439444
ng1 = g1.shape[-1]
440445
# nb x nloc x nnei x 1, nb x nloc x nnei x 3

0 commit comments

Comments
 (0)