Skip to content

Commit ac0ebb9

Browse files
iProzdOutisLi
authored andcommitted
Update repformers.py
1 parent 982dc55 commit ac0ebb9

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

deepmd/pt/model/descriptor/repformers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,6 @@ def forward(
409409
]:
410410
if comm_dict is None:
411411
assert mapping is not None
412-
assert extended_atype_embd is not None
413412
nframes, nloc, nnei = nlist.shape
414413
nall = extended_coord.view(nframes, -1).shape[1] // 3
415414
atype = extended_atype[:, :nloc]
@@ -433,13 +432,9 @@ def forward(
433432
sw = sw.masked_fill(~nlist_mask, 0.0)
434433

435434
# [nframes, nloc, tebd_dim]
436-
if comm_dict is None:
437-
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
438-
atype_embd = extended_atype_embd[:, :nloc, :]
439-
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
440-
else:
441-
atype_embd = extended_atype_embd
442-
assert isinstance(atype_embd, torch.Tensor) # for jit
435+
assert extended_atype_embd is not None
436+
atype_embd = extended_atype_embd[:, :nloc, :]
437+
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
443438
g1 = self.act(atype_embd)
444439
ng1 = g1.shape[-1]
445440
# nb x nloc x nnei x 1, nb x nloc x nnei x 3

0 commit comments

Comments
 (0)