File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -189,6 +189,10 @@ def unpack_one(to_unpack, unpack_pattern = None):
189189def exclusive_cumsum (t , dim = - 1 ):
190190 return t .cumsum (dim = dim ) - t
191191
192+ @typecheck
193+ def symmetrize (t : Float ['b n n ...' ]) -> Float ['b n n ...' ]:
194+ return t + rearrange (t , 'b i j ... -> b j i ...' )
195+
192196@typecheck
193197def masked_average (
194198 t : Shaped ['...' ],
@@ -3427,8 +3431,7 @@ def forward(
34273431 pairwise_repr = batch_repeat_interleave_pairwise (pairwise_repr , molecule_atom_lens )
34283432 pairwise_repr = pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
34293433
3430- symmetrized_pairwise_repr = pairwise_repr + rearrange (pairwise_repr , 'b i j d -> b j i d' )
3431- logits = self .to_distogram_logits (symmetrized_pairwise_repr )
3434+ logits = self .to_distogram_logits (symmetrize (pairwise_repr ))
34323435
34333436 return logits
34343437
@@ -3586,8 +3589,7 @@ def forward(
35863589
35873590 # to logits
35883591
3589- symmetric_pairwise_repr = pairwise_repr + rearrange (pairwise_repr , 'b i j d -> b j i d' )
3590- pde_logits = self .to_pde_logits (symmetric_pairwise_repr )
3592+ pde_logits = self .to_pde_logits (symmetrize (pairwise_repr ))
35913593
35923594 plddt_logits = self .to_plddt_logits (single_repr )
35933595 resolved_logits = self .to_resolved_logits (single_repr )
You can’t perform that action at this time.
0 commit comments