5555
5656# constants
5757
58+ DIM_ADDITIONAL_RESIDUE_FEATS = 10
59+
5860LinearNoBias = partial (Linear , bias = False )
5961
6062# helper functions
@@ -1435,7 +1437,8 @@ def forward(
14351437 w = self .atoms_per_window
14361438 is_unpacked_repr = exists (w )
14371439
1438- assert is_unpacked_repr ^ exists (residue_atom_lens ), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
1440+ if not is_unpacked_repr :
1441+ assert exists (residue_atom_lens ), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
14391442
14401443 atom_feats = self .proj (atom_feats )
14411444
@@ -1613,7 +1616,8 @@ def forward(
16131616 w = self .atoms_per_window
16141617 is_unpacked_repr = exists (w )
16151618
1616- assert is_unpacked_repr ^ exists (residue_atom_lens )
1619+ if not is_unpacked_repr :
1620+ assert exists (residue_atom_lens )
16171621
16181622 # in the paper, it seems they pack the atom feats
16191623 # but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
@@ -2350,7 +2354,6 @@ def __init__(
23502354 self ,
23512355 * ,
23522356 dim_atom_inputs ,
2353- dim_additional_residue_feats = 10 ,
23542357 atoms_per_window = 27 ,
23552358 dim_atom = 128 ,
23562359 dim_atompair = 16 ,
@@ -2396,9 +2399,7 @@ def __init__(
23962399 atoms_per_window = atoms_per_window
23972400 )
23982401
2399- dim_single_input = dim_token + dim_additional_residue_feats
2400-
2401- self .dim_additional_residue_feats = dim_additional_residue_feats
2402+ dim_single_input = dim_token + DIM_ADDITIONAL_RESIDUE_FEATS
24022403
24032404 self .single_input_to_single_init = LinearNoBias (dim_single_input , dim_single )
24042405 self .single_input_to_pairwise_init = LinearNoBiasThenOuterSum (dim_single_input , dim_pairwise )
@@ -2415,7 +2416,7 @@ def forward(
24152416
24162417 ) -> EmbeddedInputs :
24172418
2418- assert additional_residue_feats .shape [- 1 ] == self . dim_additional_residue_feats
2419+ assert additional_residue_feats .shape [- 1 ] == DIM_ADDITIONAL_RESIDUE_FEATS
24192420
24202421 w = self .atoms_per_window
24212422
@@ -2608,7 +2609,6 @@ def __init__(
26082609 self ,
26092610 * ,
26102611 dim_atom_inputs ,
2611- dim_additional_residue_feats ,
26122612 dim_template_feats ,
26132613 dim_template_model = 64 ,
26142614 atoms_per_window = 27 ,
@@ -2713,7 +2713,6 @@ def __init__(
27132713
27142714 self .input_embedder = InputFeatureEmbedder (
27152715 dim_atom_inputs = dim_atom_inputs ,
2716- dim_additional_residue_feats = dim_additional_residue_feats ,
27172716 atoms_per_window = atoms_per_window ,
27182717 dim_atom = dim_atom ,
27192718 dim_atompair = dim_atompair ,
@@ -2723,7 +2722,7 @@ def __init__(
27232722 ** input_embedder_kwargs
27242723 )
27252724
2726- dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
2725+ dim_single_inputs = dim_input_embedder_token + DIM_ADDITIONAL_RESIDUE_FEATS
27272726
27282727 # relative positional encoding
27292728 # used by pairwise in main alphafold2 trunk
@@ -2866,22 +2865,28 @@ def forward(
28662865
28672866 atom_seq_len = atom_inputs .shape [- 2 ]
28682867
2868+ assert exists (residue_atom_lens ) or exists (atom_mask )
2869+
28692870 # determine whether using packed or unpacked atom rep
28702871
2871- assert exists (residue_atom_lens ) ^ exists (atom_mask ), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
2872+ if self .packed_atom_repr :
2873+ assert exists (residue_atom_lens ), 'residue_atom_lens must be given if using packed atom repr'
28722874
28732875 if exists (residue_atom_lens ):
2874- assert self .packed_atom_repr , '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'
28752876
2876- # handle atom mask
2877+ if self .packed_atom_repr :
2878+ # handle atom mask
28772879
2878- total_atoms = residue_atom_lens .sum (dim = - 1 )
2879- atom_mask = lens_to_mask (total_atoms , max_len = atom_seq_len )
2880+ total_atoms = residue_atom_lens .sum (dim = - 1 )
2881+ atom_mask = lens_to_mask (total_atoms , max_len = atom_seq_len )
28802882
2881- # handle offsets for residue atom indices
2883+ # handle offsets for residue atom indices
28822884
2883- if exists (residue_atom_indices ):
2884- residue_atom_indices += F .pad (residue_atom_lens , (- 1 , 1 ), value = 0 )
2885+ if exists (residue_atom_indices ):
2886+ residue_atom_indices += F .pad (residue_atom_lens , (- 1 , 1 ), value = 0 )
2887+ else :
2888+ atom_mask = lens_to_mask (residue_atom_lens , max_len = self .atoms_per_window )
2889+ atom_mask = rearrange (atom_mask , 'b ... -> b (...)' )
28852890
28862891 # get atom sequence length and residue sequence length depending on whether using packed atomic seq
28872892
0 commit comments