Commit 2865d1b (authored)

Restrict numeric workarounds to ROCm GPUs (#298)

* Update alphafold3.py
* Update utils.py
* Update alphafold3.py

1 parent 165d795 commit 2865d1b
File tree

2 files changed: +67 additions, -39 deletions

2 files changed

+67
-39
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 51 additions & 39 deletions
@@ -107,6 +107,7 @@
     calculate_weighted_rigid_align_weights,
     pack_one
 )
+from alphafold3_pytorch.utils.utils import get_gpu_type, not_exists

 from alphafold3_pytorch.utils.model_utils import distance_to_dgram

@@ -208,8 +209,8 @@

 # NOTE: for some types of (e.g., AMD ROCm) GPUs, this represents
 # the maximum number of elements that can be processed simultaneously
-# by backpropagation for a given loss tensor
-MAX_ELEMENTS_FOR_BACKPROP = int(2e8)
+# for a given tensor. For reference, see https://github.com/pytorch/pytorch/issues/136291.
+MAX_CONCURRENT_TENSOR_ELEMENTS = int(2e9) if "ROCm" in get_gpu_type() else float("inf")

 LinearNoBias = partial(Linear, bias = False)

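Note that the new constant only bites on ROCm devices: everywhere else it evaluates to float("inf"), so the comparisons against it elsewhere in the file can never trigger the workaround. A minimal sketch restating that gate (the helper name max_concurrent_tensor_elements is hypothetical, used only for illustration):

def max_concurrent_tensor_elements(gpu_type: str):
    # Restates the gate above: cap at ~2e9 elements on ROCm, no cap elsewhere.
    return int(2e9) if "ROCm" in gpu_type else float("inf")

assert max_concurrent_tensor_elements("ROCm GPU detected") == int(2e9)
assert max_concurrent_tensor_elements("NVIDIA GPU detected") == float("inf")
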
@@ -756,47 +757,43 @@ def forward(
 # triangle is axial attention w/ itself projected for bias

 class AttentionPairBias(Module):
-    def __init__(
-        self,
-        *,
-        heads,
-        dim_pairwise,
-        window_size = None,
-        num_memory_kv = 0,
-        **attn_kwargs
-    ):
+    """An Attention module with pair bias computation."""
+
+    def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **attn_kwargs):
         super().__init__()

         self.window_size = window_size

         self.attn = Attention(
-            heads = heads,
-            window_size = window_size,
-            num_memory_kv = num_memory_kv,
-            **attn_kwargs
+            heads=heads, window_size=window_size, num_memory_kv=num_memory_kv, **attn_kwargs
         )

         # line 8 of Algorithm 24

         to_attn_bias_linear = LinearNoBias(dim_pairwise, heads)
         nn.init.zeros_(to_attn_bias_linear.weight)

-        self.to_attn_bias = nn.Sequential(
-            nn.LayerNorm(dim_pairwise),
-            to_attn_bias_linear,
-            Rearrange('b ... h -> b h ...')
-        )
+        self.to_attn_bias_norm = nn.LayerNorm(dim_pairwise)
+        self.to_attn_bias = nn.Sequential(to_attn_bias_linear, Rearrange("b ... h -> b h ..."))

     @typecheck
     def forward(
         self,
-        single_repr: Float['b n ds'],
+        single_repr: Float["b n ds"],  # type: ignore
         *,
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
-        attn_bias: Float['b n n'] | Float['b nw w (w*2)'] | None = None,
-        **kwargs
-    ) -> Float['b n ds']:
+        pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"],  # type: ignore
+        attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None,  # type: ignore
+        **kwargs,
+    ) -> Float["b n ds"]:  # type: ignore
+        """Perform the forward pass.

+        :param single_repr: The single representation tensor.
+        :param pairwise_repr: The pairwise representation tensor.
+        :param attn_bias: The attention bias tensor.
+        :return: The output tensor.
+        """
+        b, dp = pairwise_repr.shape[0], pairwise_repr.shape[-1]
+        dtype, device = pairwise_repr.dtype, pairwise_repr.device
         w, has_window_size = self.window_size, exists(self.window_size)

         # take care of windowing logic
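
The main structural change in this hunk is that the nn.LayerNorm is pulled out of the nn.Sequential into a standalone to_attn_bias_norm attribute, so the norm can be called on its own in the oversized-tensor path of the next hunk. A small self-contained sketch (illustrative dimensions and standalone modules, not code from the commit) showing the split is behavior-preserving:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

dim_pairwise, heads = 16, 4
norm = nn.LayerNorm(dim_pairwise)
linear = nn.Linear(dim_pairwise, heads, bias=False)

# old form: norm fused into the Sequential; new form: norm applied separately
combined = nn.Sequential(norm, linear, Rearrange("b ... h -> b h ..."))
split = nn.Sequential(linear, Rearrange("b ... h -> b h ..."))

x = torch.randn(2, 8, 8, dim_pairwise)
assert torch.allclose(split(norm(x)), combined(x))  # same output, norm now reusable on its own
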
@@ -811,27 +808,42 @@ def forward(

         if has_window_size:
             if not windowed_pairwise:
-                pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size = w)
+                pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size=w)
             if exists(attn_bias):
-                attn_bias = full_attn_bias_to_windowed(attn_bias, window_size = w)
+                attn_bias = full_attn_bias_to_windowed(attn_bias, window_size=w)
         else:
-            assert not windowed_pairwise, 'cannot pass in windowed pairwise repr if no window_size given to AttentionPairBias'
-            assert not exists(windowed_attn_bias) or not windowed_attn_bias, 'cannot pass in windowed attention bias if no window_size set for AttentionPairBias'
+            assert (
+                not windowed_pairwise
+            ), "Cannot pass in windowed pairwise representation if no `window_size` given to `AttentionPairBias`."
+            assert (
+                not_exists(windowed_attn_bias) or not windowed_attn_bias
+            ), "Cannot pass in windowed attention bias if no `window_size` is set for `AttentionPairBias`."

         # attention bias preparation with further addition from pairwise repr

         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'b ... -> b 1 ...')
+            attn_bias = rearrange(attn_bias, "b ... -> b 1 ...")
         else:
-            attn_bias = 0.
+            attn_bias = 0.0
+
+        if pairwise_repr.numel() > MAX_CONCURRENT_TENSOR_ELEMENTS:
+            # create a stub tensor and normalize it to maintain gradients to `to_attn_bias_norm`
+            stub_pairwise_repr = torch.zeros((b, dp), dtype=dtype, device=device)
+            stub_attn_bias_norm = self.to_attn_bias_norm(stub_pairwise_repr) * 0.0
+
+            # adjust `attn_bias_norm` dimensions to match `pairwise_repr`
+            attn_bias_norm = pairwise_repr + (
+                stub_attn_bias_norm[:, None, None, None, :]
+                if windowed_pairwise
+                else stub_attn_bias_norm[:, None, None, :]
+            )

-        attn_bias = self.to_attn_bias(pairwise_repr) + attn_bias
+            # apply bias transformation
+            attn_bias = self.to_attn_bias(attn_bias_norm) + attn_bias
+        else:
+            attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias

-        out = self.attn(
-            single_repr,
-            attn_bias = attn_bias,
-            **kwargs
-        )
+        out = self.attn(single_repr, attn_bias=attn_bias, **kwargs)

         return out

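When pairwise_repr exceeds the element cap, the new branch skips normalizing the full tensor but still routes a zero-valued stub through to_attn_bias_norm (scaled by 0.0) so the norm's parameters remain in the autograd graph. A standalone sketch of that trick with illustrative shapes and variable names (not code from the commit):

import torch
import torch.nn as nn

norm = nn.LayerNorm(8)                          # stands in for to_attn_bias_norm
big = torch.randn(2, 4, 4, 8, requires_grad=True)   # stands in for an oversized pairwise_repr
stub = torch.zeros(2, 8)                        # tiny (batch, feature) stub tensor

stub_out = norm(stub) * 0.0                     # zero-valued, but still touches the norm's parameters
biased = big + stub_out[:, None, None, :]       # broadcast the zero term back onto the large tensor

biased.sum().backward()
assert norm.weight.grad is not None             # the LayerNorm stays in the autograd graph
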
@@ -2919,7 +2931,7 @@ def forward(
         bond_losses = F.mse_loss(denoised_cdist, normalized_cdist, reduction = 'none')
         bond_losses = bond_losses * loss_weights

-        if atompair_mask.sum() > MAX_ELEMENTS_FOR_BACKPROP:
+        if atompair_mask.sum() > MAX_CONCURRENT_TENSOR_ELEMENTS:
             if verbose:
                 logger.info("Subsetting atom pairs for backprop within EDM")

@@ -2928,7 +2940,7 @@ def forward(
             flat_atompair_mask_indices = torch.arange(atompair_mask.numel(), device=self.device)[atompair_mask.view(-1)]
             num_true_atompairs = flat_atompair_mask_indices.size(0)

-            num_atompairs_to_ignore = num_true_atompairs - MAX_ELEMENTS_FOR_BACKPROP
+            num_atompairs_to_ignore = num_true_atompairs - MAX_CONCURRENT_TENSOR_ELEMENTS
             ignored_atompair_indices = flat_atompair_mask_indices[torch.randperm(num_true_atompairs)[:num_atompairs_to_ignore]]

             atompair_mask.view(-1)[ignored_atompair_indices] = False
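
These two hunks only rename the constant; the surrounding (unchanged) logic randomly switches off surplus True entries in atompair_mask so the number of atom pairs contributing to backprop stays within the cap. A minimal sketch of that subsetting pattern (cap_true_entries is a hypothetical helper for illustration, not part of the codebase):

import torch

def cap_true_entries(mask: torch.Tensor, max_true: int) -> torch.Tensor:
    # randomly disable surplus True entries so at most `max_true` remain
    flat = mask.view(-1)
    true_idx = torch.arange(flat.numel(), device=mask.device)[flat]
    surplus = true_idx.numel() - max_true
    if surplus > 0:
        drop = true_idx[torch.randperm(true_idx.numel())[:surplus]]
        flat[drop] = False
    return mask

m = torch.rand(1000) > 0.5
capped = cap_true_entries(m.clone(), max_true=100)
assert int(capped.sum()) <= 100
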

alphafold3_pytorch/utils/utils.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,5 @@
+import torch
+
 import numpy as np

 from beartype.typing import Any, Iterable, List
@@ -61,3 +63,17 @@ def np_mode(x: np.ndarray) -> Any:
     values, counts = np.unique(x, return_counts=True)
     m = counts.argmax()
     return values[m], counts[m]
+
+
+def get_gpu_type() -> str:
+    """Return the type of GPU detected: NVIDIA, ROCm, or Unknown."""
+    if torch.cuda.is_available():
+        device_name = torch.cuda.get_device_name(0).lower()
+        if "nvidia" in device_name:
+            return "NVIDIA GPU detected"
+        elif "amd" in device_name or "gfx" in device_name:
+            return "ROCm GPU detected"
+        else:
+            return "Unknown GPU type"
+    else:
+        return "No GPU available"
