@@ -131,12 +131,13 @@ def mean_pool_with_lens(
131131
132132@typecheck
133133def repeat_consecutive_with_lens (
134- feats : Float ['b n d ' ],
134+ feats : Float ['b n ...' ] | Bool [ 'b n ' ],
135135 lens : Int ['b n' ],
136136 max_length : int | None = None ,
137137 return_mask = False
138- ) -> Float ['b m d' ] | Tuple [Float ['b m d' ], Bool ['b m' ]]:
138+ ) -> Float ['b m d' ] | Bool [ 'b m' ] | Tuple [Float ['b m d' ] | Bool [ 'b m ' ], Bool ['b m' ]]:
139139
140+ is_bool = feats .dtype == torch .bool
140141 device = feats .device
141142
142143 # derive arange from the max length
@@ -165,8 +166,11 @@ def repeat_consecutive_with_lens(
165166
166167 # now broadcast and sum for consecutive features
167168
168- feats = einx .multiply ('b n d, b n m -> b n m d' , feats , consecutive_mask .float ())
169- feats = reduce (feats , 'b n m d -> b m d' , 'sum' )
169+ feats = einx .multiply ('b n ..., b n m -> b n m ...' , feats , consecutive_mask .float ())
170+ feats = reduce (feats , 'b n m ... -> b m ...' , 'sum' )
171+
172+ if is_bool :
173+ feats = feats .bool ()
170174
171175 if not return_mask :
172176 return feats
@@ -1373,12 +1377,24 @@ def forward(
13731377 self ,
13741378 * ,
13751379 atom_feats : Float ['b m da' ],
1376- atom_mask : Bool ['b m' ]
1380+ atom_mask : Bool ['b m' ],
1381+ residue_atom_lens : Int ['b n' ] | None = None
13771382 ) -> Float ['b n ds' ]:
1383+
13781384 w = self .atoms_per_window
1385+ is_unpacked_repr = exists (w )
1386+
1387+ assert is_unpacked_repr ^ exists (residue_atom_lens ), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
13791388
13801389 atom_feats = self .proj (atom_feats )
13811390
1391+ # packed atom representation
1392+
1393+ if exists (residue_atom_lens ):
1394+ tokens = mean_pool_with_lens (atom_feats , residue_atom_lens )
1395+ return tokens
1396+
1397+ # unpacked atom representation
13821398 # masked mean pool the atom feats for each residue, for the token transformer
13831399 # this is basically a simple 2-level hierarchical transformer
13841400
@@ -1541,13 +1557,18 @@ def forward(
15411557 single_inputs_repr : Float ['b n dsi' ],
15421558 pairwise_trunk : Float ['b n n dpt' ],
15431559 pairwise_rel_pos_feats : Float ['b n n dpr' ],
1560+ residue_atom_lens : Int ['b n' ] | None = None
15441561 ):
15451562 w = self .atoms_per_window
1563+ is_unpacked_repr = exists (w )
1564+
1565+ assert is_unpacked_repr ^ exists (residue_atom_lens )
15461566
15471567 # in the paper, it seems they pack the atom feats
15481568 # but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
15491569
1550- assert divisible_by (noised_atom_pos .shape [- 2 ], w )
1570+ if is_unpacked_repr :
1571+ assert divisible_by (noised_atom_pos .shape [- 2 ], w )
15511572
15521573 conditioned_single_repr = self .single_conditioner (
15531574 times = times ,
@@ -1572,14 +1593,20 @@ def forward(
15721593
15731594 single_repr_cond = self .single_repr_to_atom_feat_cond (conditioned_single_repr )
15741595
1575- single_repr_cond = repeat (single_repr_cond , 'b n ds -> b (n w) ds' , w = w )
1596+ if is_unpacked_repr :
1597+ single_repr_cond = repeat (single_repr_cond , 'b n ds -> b (n w) ds' , w = w )
1598+ else :
1599+ single_repr_cond = repeat_consecutive_with_lens (single_repr_cond , residue_atom_lens )
1600+
15761601 atom_feats_cond = single_repr_cond + atom_feats_cond
15771602
15781603 # condition atompair feats with pairwise repr
15791604
15801605 pairwise_repr_cond = self .pairwise_repr_to_atompair_feat_cond (conditioned_pairwise_repr )
1581- pairwise_repr_cond = repeat (pairwise_repr_cond , 'b i j dp -> b (i w1) (j w2) dp' , w1 = w , w2 = w )
1582- atompair_feats = pairwise_repr_cond + atompair_feats
1606+
1607+ if is_unpacked_repr :
1608+ pairwise_repr_cond = repeat (pairwise_repr_cond , 'b i j dp -> b (i w1) (j w2) dp' , w1 = w , w2 = w )
1609+ atompair_feats = pairwise_repr_cond + atompair_feats
15831610
15841611 # condition atompair feats further with single atom repr
15851612
@@ -1603,8 +1630,10 @@ def forward(
16031630
16041631 tokens = self .atom_feats_to_pooled_token (
16051632 atom_feats = atom_feats ,
1606- atom_mask = atom_mask
1633+ atom_mask = atom_mask ,
1634+ residue_atom_lens = residue_atom_lens
16071635 )
1636+
16081637 # token transformer
16091638
16101639 tokens = self .cond_tokens_with_cond_single (conditioned_single_repr ) + tokens
@@ -1621,7 +1650,11 @@ def forward(
16211650 # atom decoder
16221651
16231652 atom_decoder_input = self .tokens_to_atom_decoder_input_cond (tokens )
1624- atom_decoder_input = repeat (atom_decoder_input , 'b n da -> b (n w) da' , w = w )
1653+
1654+ if is_unpacked_repr :
1655+ atom_decoder_input = repeat (atom_decoder_input , 'b n da -> b (n w) da' , w = w )
1656+ else :
1657+ atom_decoder_input = repeat_consecutive_with_lens (atom_decoder_input , residue_atom_lens )
16251658
16261659 atom_decoder_input = atom_decoder_input + atom_feats_skip
16271660
@@ -1771,7 +1804,7 @@ def sample_schedule(self, num_sample_steps = None):
17711804 @torch .no_grad ()
17721805 def sample (
17731806 self ,
1774- atom_mask : Bool ['b m' ],
1807+ atom_mask : Bool ['b m' ] | None = None ,
17751808 num_sample_steps = None ,
17761809 clamp = True ,
17771810 ** network_condition_kwargs
@@ -1847,6 +1880,7 @@ def forward(
18471880 pairwise_trunk : Float ['b n n dpt' ],
18481881 pairwise_rel_pos_feats : Float ['b n n dpr' ],
18491882 return_denoised_pos = False ,
1883+ residue_atom_lens : Int ['b n' ] | None = None ,
18501884 additional_residue_feats : Float ['b n 10' ] | None = None ,
18511885 add_smooth_lddt_loss = False ,
18521886 add_bond_loss = False ,
@@ -1878,6 +1912,7 @@ def forward(
18781912 single_inputs_repr = single_inputs_repr ,
18791913 pairwise_trunk = pairwise_trunk ,
18801914 pairwise_rel_pos_feats = pairwise_rel_pos_feats ,
1915+ residue_atom_lens = residue_atom_lens
18811916 )
18821917 )
18831918
@@ -1890,9 +1925,14 @@ def forward(
18901925
18911926 if exists (additional_residue_feats ):
18921927 w = self .net .atoms_per_window
1928+ is_unpacked_repr = exists (w )
18931929
1894- is_nucleotide_or_ligand_fields = additional_residue_feats [..., 7 :] != 0.
1895- atom_is_dna , atom_is_rna , atom_is_ligand = tuple (repeat (t != 0. , 'b n -> b (n w)' , w = w ) for t in is_nucleotide_or_ligand_fields .unbind (dim = - 1 ))
1930+ is_nucleotide_or_ligand_fields = (additional_residue_feats [..., 7 :] != 0. ).unbind (dim = - 1 )
1931+
1932+ if is_unpacked_repr :
1933+ atom_is_dna , atom_is_rna , atom_is_ligand = tuple (repeat (t != 0. , 'b n -> b (n w)' , w = w ) for t in is_nucleotide_or_ligand_fields )
1934+ else :
1935+ atom_is_dna , atom_is_rna , atom_is_ligand = tuple (repeat_consecutive_with_lens (t , residue_atom_lens ) for t in is_nucleotide_or_ligand_fields )
18961936
18971937 # section 3.7.1 equation 4
18981938
@@ -2315,6 +2355,8 @@ def forward(
23152355 atom_mask : Bool ['b m' ],
23162356 atompair_feats : Float ['b m m dap' ],
23172357 additional_residue_feats : Float ['b n rf' ],
2358+ residue_atom_lens : Int ['b n' ] | None = None ,
2359+
23182360 ) -> EmbeddedInputs :
23192361
23202362 assert additional_residue_feats .shape [- 1 ] == self .dim_additional_residue_feats
@@ -2336,7 +2378,8 @@ def forward(
23362378
23372379 single_inputs = self .atom_feats_to_pooled_token (
23382380 atom_feats = atom_feats ,
2339- atom_mask = atom_mask
2381+ atom_mask = atom_mask ,
2382+ residue_atom_lens = residue_atom_lens
23402383 )
23412384
23422385 single_inputs = torch .cat ((single_inputs , additional_residue_feats ), dim = - 1 )
@@ -2527,6 +2570,7 @@ def __init__(
25272570 num_pae_bins = 64 ,
25282571 sigma_data = 16 ,
25292572 diffusion_num_augmentations = 4 ,
2573+ packed_atom_repr = False ,
25302574 loss_confidence_weight = 1e-4 ,
25312575 loss_distogram_weight = 1e-2 ,
25322576 loss_diffusion_weight = 4. ,
@@ -2593,6 +2637,15 @@ def __init__(
25932637 ):
25942638 super ().__init__ ()
25952639
2640+ # whether a packed atom representation is being used
2641+
2642+ self .packed_atom_repr = packed_atom_repr
2643+
2644+ # atoms per window if using unpacked representation
2645+
2646+ if packed_atom_repr :
2647+ atoms_per_window = None
2648+
25962649 self .atoms_per_window = atoms_per_window
25972650
25982651 # augmentation
@@ -2732,9 +2785,10 @@ def forward(
27322785 self ,
27332786 * ,
27342787 atom_inputs : Float ['b m dai' ],
2735- atom_mask : Bool ['b m' ],
27362788 atompair_feats : Float ['b m m dap' ],
27372789 additional_residue_feats : Float ['b n 10' ],
2790+ residue_atom_lens : Int ['b n' ] | None = None ,
2791+ atom_mask : Bool ['b m' ] | None = None ,
27382792 token_bond : Bool ['b n n' ] | None = None ,
27392793 msa : Float ['b s n d' ] | None = None ,
27402794 msa_mask : Bool ['b s' ] | None = None ,
@@ -2754,13 +2808,33 @@ def forward(
27542808 return_loss_breakdown = False
27552809 ) -> Float ['b m 3' ] | Float ['' ] | Tuple [Float ['' ], LossBreakdown ]:
27562810
2757- # get atom sequence length and residue sequence length
2758-
2759- w = self .atoms_per_window
27602811 atom_seq_len = atom_inputs .shape [- 2 ]
27612812
2762- assert divisible_by (atom_seq_len , w )
2763- seq_len = atom_inputs .shape [- 2 ] // w
2813+ # determine whether using packed or unpacked atom rep
2814+
2815+ assert exists (residue_atom_lens ) ^ exists (atom_mask ), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
2816+
2817+ if exists (residue_atom_lens ):
2818+ assert self .packed_atom_repr , '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'
2819+
2820+ # handle atom mask
2821+
2822+ atom_mask = lens_to_mask (residue_atom_lens )
2823+ atom_mask = atom_mask [:, :atom_seq_len ]
2824+
2825+ # handle offsets for residue atom indices
2826+
2827+ if exists (residue_atom_indices ):
2828+ residue_atom_indices += F .pad (residue_atom_lens , (- 1 , 1 ), value = 0 )
2829+
2830+ # get atom sequence length and residue sequence length depending on whether using packed atomic seq
2831+
2832+ if self .packed_atom_repr :
2833+ seq_len = residue_atom_lens .shape [- 1 ]
2834+ else :
2835+ w = self .atoms_per_window
2836+ assert divisible_by (atom_seq_len , w )
2837+ seq_len = atom_inputs .shape [- 2 ] // w
27642838
27652839 # embed inputs
27662840
@@ -2774,7 +2848,8 @@ def forward(
27742848 atom_inputs = atom_inputs ,
27752849 atom_mask = atom_mask ,
27762850 atompair_feats = atompair_feats ,
2777- additional_residue_feats = additional_residue_feats
2851+ additional_residue_feats = additional_residue_feats ,
2852+ residue_atom_lens = residue_atom_lens
27782853 )
27792854
27802855 # relative positional encoding
@@ -2805,7 +2880,12 @@ def forward(
28052880
28062881 # pairwise mask
28072882
2808- mask = reduce (atom_mask , 'b (n w) -> b n' , w = w , reduction = 'any' )
2883+ if self .packed_atom_repr :
2884+ mask = lens_to_mask (residue_atom_lens )
2885+ mask = mask [:, :seq_len ]
2886+ else :
2887+ mask = reduce (atom_mask , 'b (n w) -> b n' , w = w , reduction = 'any' )
2888+
28092889 pairwise_mask = einx .logical_and ('b i, b j -> b i j' , mask , mask )
28102890
28112891 # init recycled single and pairwise
@@ -2889,7 +2969,8 @@ def forward(
28892969 single_trunk_repr = single ,
28902970 single_inputs_repr = single_inputs ,
28912971 pairwise_trunk = pairwise ,
2892- pairwise_rel_pos_feats = relative_position_encoding
2972+ pairwise_rel_pos_feats = relative_position_encoding ,
2973+ residue_atom_lens = residue_atom_lens
28932974 )
28942975
28952976 # losses default to 0
@@ -2903,7 +2984,12 @@ def forward(
29032984 # distogram head
29042985
29052986 if not exists (distance_labels ) and atom_pos_given and exists (residue_atom_indices ):
2906- residue_pos = einx .get_at ('b (n [w]) c, b n -> b n c' , atom_pos , residue_atom_indices )
2987+
2988+ if self .packed_atom_repr :
2989+ residue_pos = einx .get_at ('b [m] c, b n -> b n c' , atom_pos , residue_atom_indices )
2990+ else :
2991+ residue_pos = einx .get_at ('b (n [w]) c, b n -> b n c' , atom_pos , residue_atom_indices )
2992+
29072993 residue_dist = torch .cdist (residue_pos , residue_pos , p = 2 )
29082994 dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , residue_dist , self .distance_bins ).abs ()
29092995 distance_labels = dist_from_dist_bins .argmin (dim = - 1 )
@@ -2938,6 +3024,7 @@ def forward(
29383024 relative_position_encoding ,
29393025 additional_residue_feats ,
29403026 residue_atom_indices ,
3027+ residue_atom_lens ,
29413028 pae_labels ,
29423029 pde_labels ,
29433030 plddt_labels ,
@@ -2958,6 +3045,7 @@ def forward(
29583045 relative_position_encoding ,
29593046 additional_residue_feats ,
29603047 residue_atom_indices ,
3048+ residue_atom_lens ,
29613049 pae_labels ,
29623050 pde_labels ,
29633051 plddt_labels ,
@@ -2980,6 +3068,7 @@ def forward(
29803068 single_inputs_repr = single_inputs ,
29813069 pairwise_trunk = pairwise ,
29823070 pairwise_rel_pos_feats = relative_position_encoding ,
3071+ residue_atom_lens = residue_atom_lens ,
29833072 return_denoised_pos = True ,
29843073 )
29853074
@@ -2990,7 +3079,10 @@ def forward(
29903079
29913080 if calc_diffusion_loss and should_call_confidence_head :
29923081
2993- pred_atom_pos = einx .get_at ('b (n [w]) c, b n -> b n c' , denoised_atom_pos , residue_atom_indices )
3082+ if self .packed_atom_repr :
3083+ pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , denoised_atom_pos , residue_atom_indices )
3084+ else :
3085+ pred_atom_pos = einx .get_at ('b (n [w]) c, b n -> b n c' , denoised_atom_pos , residue_atom_indices )
29943086
29953087 logits = self .confidence_head (
29963088 single_repr = single ,
0 commit comments