Skip to content

Commit c31c9f2

Browse files
committed
cleanup symmetrizing pairwise repr
1 parent 27ec383 commit c31c9f2

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def unpack_one(to_unpack, unpack_pattern = None):
189189
def exclusive_cumsum(t, dim = -1):
190190
return t.cumsum(dim = dim) - t
191191

192+
@typecheck
193+
def symmetrize(t: Float['b n n ...']) -> Float['b n n ...']:
194+
return t + rearrange(t, 'b i j ... -> b j i ...')
195+
192196
@typecheck
193197
def masked_average(
194198
t: Shaped['...'],
@@ -3427,8 +3431,7 @@ def forward(
34273431
pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
34283432
pairwise_repr = pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
34293433

3430-
symmetrized_pairwise_repr = pairwise_repr + rearrange(pairwise_repr, 'b i j d -> b j i d')
3431-
logits = self.to_distogram_logits(symmetrized_pairwise_repr)
3434+
logits = self.to_distogram_logits(symmetrize(pairwise_repr))
34323435

34333436
return logits
34343437

@@ -3586,8 +3589,7 @@ def forward(
35863589

35873590
# to logits
35883591

3589-
symmetric_pairwise_repr = pairwise_repr + rearrange(pairwise_repr, 'b i j d -> b j i d')
3590-
pde_logits = self.to_pde_logits(symmetric_pairwise_repr)
3592+
pde_logits = self.to_pde_logits(symmetrize(pairwise_repr))
35913593

35923594
plddt_logits = self.to_plddt_logits(single_repr)
35933595
resolved_logits = self.to_resolved_logits(single_repr)

0 commit comments

Comments
 (0)