@@ -409,6 +409,7 @@ def forward(
409409 ]:
410410 if comm_dict is None :
411411 assert mapping is not None
412+ assert extended_atype_embd is not None
412413 nframes , nloc , nnei = nlist .shape
413414 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
414415 atype = extended_atype [:, :nloc ]
@@ -432,9 +433,13 @@ def forward(
432433 sw = sw .masked_fill (~ nlist_mask , 0.0 )
433434
434435 # [nframes, nloc, tebd_dim]
435- assert extended_atype_embd is not None
436- atype_embd = extended_atype_embd [:, :nloc , :]
437- assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
436+ if comm_dict is None :
437+ assert isinstance (extended_atype_embd , torch .Tensor ) # for jit
438+ atype_embd = extended_atype_embd [:, :nloc , :]
439+ assert list (atype_embd .shape ) == [nframes , nloc , self .g1_dim ]
440+ else :
441+ atype_embd = extended_atype_embd
442+ assert isinstance (atype_embd , torch .Tensor ) # for jit
438443 g1 = self .act (atype_embd )
439444 ng1 = g1 .shape [- 1 ]
440445 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
0 commit comments