22 changes: 8 additions & 14 deletions tritonbench/operators/gdpa/math.py
@@ -28,6 +28,11 @@
from triton.language.math import fast_dividef, fast_expf


HAS_FAST_TANH_INSTRUCTION = (
    torch.version.cuda is not None and torch.cuda.get_device_capability()[0] >= 9
)  # H100


# Don't change the order of the enum values, as they are used to index
# Only add new activation functions at the end of the enum
class Activation(str, Enum):
Expand All @@ -50,17 +55,6 @@ def activation_string_to_int(s: str):
return activation_to_int.get(enum_val)


def is_hip_or_a100():
    try:
        if triton.runtime.driver.active.get_current_target().backend == "hip":
            return True
        elif torch.cuda.get_device_capability()[0] < 9:  # A100
            return True
        return False
    except Exception:
        return False


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
@@ -79,11 +73,11 @@ def gelu_grad(x):
    return cdf + x * pdf


if is_hip_or_a100():
# For AMD or A100, use tanh as a fallback
if not HAS_FAST_TANH_INSTRUCTION:

    @triton.jit
    def tanh_approx_fp32(x):
        return tanh(x)
        return 2 * fast_dividef(1.0, 1.0 + fast_expf(-2.0 * x)) - 1.0

else:

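Note (not part of the diff): the new fallback body inlines the same scaled-sigmoid identity the `tanh` helper above relies on, tanh(x) = 2 / (1 + exp(-2x)) - 1, instead of calling `tanh(x)`. Below is a minimal plain-Python sketch of that identity, using `math.exp` in place of Triton's `fast_expf`/`fast_dividef`, purely for illustration.

```python
import math


def tanh_via_sigmoid(x: float) -> float:
    # Same identity the Triton fallback computes with fast_dividef/fast_expf:
    # tanh(x) = 2 * sigmoid(2x) - 1 = 2 / (1 + exp(-2x)) - 1
    return 2.0 / (1.0 + math.exp(-2.0 * x)) - 1.0


# Sanity check against the reference implementation.
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(tanh_via_sigmoid(x) - math.tanh(x)) < 1e-12
```

Per the diff, this software formulation is only used when `HAS_FAST_TANH_INSTRUCTION` is false (HIP backends and pre-Hopper CUDA GPUs such as A100); on compute capability 9.0+ the dedicated fast-tanh path is kept.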