@@ -107,6 +107,7 @@
     calculate_weighted_rigid_align_weights,
    pack_one
 )
+from alphafold3_pytorch.utils.utils import get_gpu_type, not_exists
 
 from alphafold3_pytorch.utils.model_utils import distance_to_dgram
 
@@ -208,8 +209,8 @@
 
 # NOTE: for some types of (e.g., AMD ROCm) GPUs, this represents
 # the maximum number of elements that can be processed simultaneously
-# by backpropagation for a given loss tensor
-MAX_ELEMENTS_FOR_BACKPROP = int(2e8)
+# for a given tensor. For reference, see https://github.com/pytorch/pytorch/issues/136291.
+MAX_CONCURRENT_TENSOR_ELEMENTS = int(2e9) if "ROCm" in get_gpu_type() else float("inf")
 
 LinearNoBias = partial(Linear, bias = False)
 
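Note on the new constant: the cap is only finite on ROCm builds, where backward passes over very large tensors can fail (see the linked PyTorch issue); everywhere else it is effectively disabled. Below is a rough sketch of how a GPU-type gate like this could be wired up. It is an illustration only: the repository's actual `get_gpu_type()` helper in `alphafold3_pytorch.utils.utils` may be implemented differently.

```python
# Illustrative sketch only -- not the library's real get_gpu_type() implementation.
import torch

def get_gpu_type() -> str:
    """Best-effort GPU description; mentions 'ROCm' on AMD (HIP) builds of PyTorch."""
    if getattr(torch.version, "hip", None):       # non-empty on ROCm builds
        return f"ROCm {torch.version.hip}"
    if torch.cuda.is_available():
        return torch.cuda.get_device_name(0)      # e.g. an NVIDIA device string
    return "cpu"

# finite cap only where the backend is known to choke on huge tensors
MAX_CONCURRENT_TENSOR_ELEMENTS = int(2e9) if "ROCm" in get_gpu_type() else float("inf")
```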
@@ -756,47 +757,43 @@ def forward(
 # triangle is axial attention w/ itself projected for bias
 
 class AttentionPairBias(Module):
-    def __init__(
-        self,
-        *,
-        heads,
-        dim_pairwise,
-        window_size = None,
-        num_memory_kv = 0,
-        **attn_kwargs
-    ):
+    """An Attention module with pair bias computation."""
+
+    def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **attn_kwargs):
         super().__init__()
 
         self.window_size = window_size
 
         self.attn = Attention(
-            heads = heads,
-            window_size = window_size,
-            num_memory_kv = num_memory_kv,
-            **attn_kwargs
+            heads=heads, window_size=window_size, num_memory_kv=num_memory_kv, **attn_kwargs
         )
 
         # line 8 of Algorithm 24
 
         to_attn_bias_linear = LinearNoBias(dim_pairwise, heads)
         nn.init.zeros_(to_attn_bias_linear.weight)
 
-        self.to_attn_bias = nn.Sequential(
-            nn.LayerNorm(dim_pairwise),
-            to_attn_bias_linear,
-            Rearrange('b ... h -> b h ...')
-        )
+        self.to_attn_bias_norm = nn.LayerNorm(dim_pairwise)
+        self.to_attn_bias = nn.Sequential(to_attn_bias_linear, Rearrange("b ... h -> b h ..."))
 
     @typecheck
     def forward(
         self,
-        single_repr: Float['b n ds'],
+        single_repr: Float["b n ds"],  # type: ignore
         *,
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
-        attn_bias: Float['b n n'] | Float['b nw w (w*2)'] | None = None,
-        **kwargs
-    ) -> Float['b n ds']:
+        pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"],  # type: ignore
+        attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None,  # type: ignore
+        **kwargs,
+    ) -> Float["b n ds"]:  # type: ignore
+        """Perform the forward pass.
 
+        :param single_repr: The single representation tensor.
+        :param pairwise_repr: The pairwise representation tensor.
+        :param attn_bias: The attention bias tensor.
+        :return: The output tensor.
+        """
+        b, dp = pairwise_repr.shape[0], pairwise_repr.shape[-1]
+        dtype, device = pairwise_repr.dtype, pairwise_repr.device
         w, has_window_size = self.window_size, exists(self.window_size)
 
         # take care of windowing logic
@@ -811,27 +808,42 @@ def forward(
 
         if has_window_size:
             if not windowed_pairwise:
-                pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size = w)
+                pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size=w)
             if exists(attn_bias):
-                attn_bias = full_attn_bias_to_windowed(attn_bias, window_size = w)
+                attn_bias = full_attn_bias_to_windowed(attn_bias, window_size=w)
         else:
-            assert not windowed_pairwise, 'cannot pass in windowed pairwise repr if no window_size given to AttentionPairBias'
-            assert not exists(windowed_attn_bias) or not windowed_attn_bias, 'cannot pass in windowed attention bias if no window_size set for AttentionPairBias'
+            assert (
+                not windowed_pairwise
+            ), "Cannot pass in windowed pairwise representation if no `window_size` given to `AttentionPairBias`."
+            assert (
+                not_exists(windowed_attn_bias) or not windowed_attn_bias
+            ), "Cannot pass in windowed attention bias if no `window_size` is set for `AttentionPairBias`."
 
         # attention bias preparation with further addition from pairwise repr
 
         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'b ... -> b 1 ...')
+            attn_bias = rearrange(attn_bias, "b ... -> b 1 ...")
         else:
-            attn_bias = 0.
+            attn_bias = 0.0
+
+        if pairwise_repr.numel() > MAX_CONCURRENT_TENSOR_ELEMENTS:
+            # create a stub tensor and normalize it to maintain gradients to `to_attn_bias_norm`
+            stub_pairwise_repr = torch.zeros((b, dp), dtype=dtype, device=device)
+            stub_attn_bias_norm = self.to_attn_bias_norm(stub_pairwise_repr) * 0.0
+
+            # adjust `attn_bias_norm` dimensions to match `pairwise_repr`
+            attn_bias_norm = pairwise_repr + (
+                stub_attn_bias_norm[:, None, None, None, :]
+                if windowed_pairwise
+                else stub_attn_bias_norm[:, None, None, :]
+            )
 
-        attn_bias = self.to_attn_bias(pairwise_repr) + attn_bias
+            # apply bias transformation
+            attn_bias = self.to_attn_bias(attn_bias_norm) + attn_bias
+        else:
+            attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias
 
-        out = self.attn(
-            single_repr,
-            attn_bias = attn_bias,
-            **kwargs
-        )
+        out = self.attn(single_repr, attn_bias=attn_bias, **kwargs)
 
         return out
 
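A note on the large-tensor branch added above: when `pairwise_repr` exceeds the element cap, the LayerNorm is applied to a tiny zero-valued stub instead of the full tensor, the result is scaled by `0.0`, and it is broadcast-added onto `pairwise_repr`. The values are unchanged, but the norm's parameters remain in the autograd graph (receiving zero gradients), so machinery that tracks parameter usage does not flag them, and the full normalized tensor is never materialized. A minimal, self-contained sketch of that trick, with made-up shapes rather than the module's real ones:

```python
# Minimal sketch of the stub-LayerNorm trick; shapes and names are illustrative only.
import torch
from torch import nn

b, n, dp = 2, 4, 8                               # hypothetical batch / sequence / pairwise dims
norm = nn.LayerNorm(dp)
pairwise_repr = torch.randn(b, n, n, dp)

stub = norm(torch.zeros(b, dp)) * 0.0            # normalize a tiny stub, then zero it out
biased = pairwise_repr + stub[:, None, None, :]  # broadcast-add: values are unchanged

biased.sum().backward()
print(norm.weight.grad is not None)              # True: the norm's parameters stayed in the graph
```

The zero multiplier means the gradient contribution is exactly zero; the point is graph membership, not a meaningful gradient signal.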
@@ -2919,7 +2931,7 @@ def forward(
             bond_losses = F.mse_loss(denoised_cdist, normalized_cdist, reduction = 'none')
             bond_losses = bond_losses * loss_weights
 
-            if atompair_mask.sum() > MAX_ELEMENTS_FOR_BACKPROP:
+            if atompair_mask.sum() > MAX_CONCURRENT_TENSOR_ELEMENTS:
                 if verbose:
                     logger.info("Subsetting atom pairs for backprop within EDM")
 
@@ -2928,7 +2940,7 @@ def forward(
                 flat_atompair_mask_indices = torch.arange(atompair_mask.numel(), device = self.device)[atompair_mask.view(-1)]
                 num_true_atompairs = flat_atompair_mask_indices.size(0)
 
-                num_atompairs_to_ignore = num_true_atompairs - MAX_ELEMENTS_FOR_BACKPROP
+                num_atompairs_to_ignore = num_true_atompairs - MAX_CONCURRENT_TENSOR_ELEMENTS
                 ignored_atompair_indices = flat_atompair_mask_indices[torch.randperm(num_true_atompairs)[:num_atompairs_to_ignore]]
 
                 atompair_mask.view(-1)[ignored_atompair_indices] = False
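The two hunks above cap how many masked atom pairs actually contribute to the backward pass: surplus True entries in the mask are chosen at random and flipped to False before the per-pair losses are reduced. A small, self-contained sketch of that pattern (the cap and mask here are invented for illustration):

```python
# Illustrative sketch of randomly capping the number of True entries in a boolean mask.
import torch

max_elements = 10                                  # hypothetical cap
atompair_mask = torch.rand(8, 8) > 0.3             # hypothetical atom-pair mask

flat_true = torch.arange(atompair_mask.numel())[atompair_mask.view(-1)]
surplus = flat_true.numel() - max_elements

if surplus > 0:
    drop = flat_true[torch.randperm(flat_true.numel())[:surplus]]
    atompair_mask.view(-1)[drop] = False           # in-place: the flat view shares storage

assert atompair_mask.sum().item() <= max_elements
```

Because the mask is edited through a flat view, the change is reflected in the original tensor, so the subsequent masked loss reduction sees only the retained pairs.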