@@ -409,7 +409,6 @@ def forward(
409409 ]:
410410 if comm_dict is None :
411411 assert mapping is not None
412- assert extended_atype_embd is not None
413412 nframes , nloc , nnei = nlist .shape
414413 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
415414 atype = extended_atype [:, :nloc ]
@@ -433,13 +432,9 @@ def forward(
433432 sw = sw .masked_fill (~ nlist_mask , 0.0 )
434433
435434 # [nframes, nloc, tebd_dim]
436- if comm_dict is None :
437- assert isinstance (extended_atype_embd , torch .Tensor ) # for jit
438- atype_embd = extended_atype_embd [:, :nloc , :]
439- assert list (atype_embd .shape ) == [nframes , nloc , self .g1_dim ]
440- else :
441- atype_embd = extended_atype_embd
442- assert isinstance (atype_embd , torch .Tensor ) # for jit
435+ assert extended_atype_embd is not None
436+ atype_embd = extended_atype_embd [:, :nloc , :]
437+ assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
443438 g1 = self .act (atype_embd )
444439 ng1 = g1 .shape [- 1 ]
445440 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
0 commit comments