From 3aeb7fb374e23ab044b0d68a2113b406d6112412 Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Thu, 2 Oct 2025 16:56:22 -0700
Subject: [PATCH] Use fastmath in GeLU (TritonBench) (#506)

Summary:
Use fastmath intrinsics in `tanh_approx_fp32` when we don't have a native
fast tanh instruction.

We need to use the sigmoid formulation rather than the ratio of two
exponentials for numerical stability.

Differential Revision: D83082730
---
 tritonbench/operators/gdpa/math.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/tritonbench/operators/gdpa/math.py b/tritonbench/operators/gdpa/math.py
index 7fedaf290..0151de7f0 100644
--- a/tritonbench/operators/gdpa/math.py
+++ b/tritonbench/operators/gdpa/math.py
@@ -28,6 +28,11 @@
 from triton.language.math import fast_dividef, fast_expf
 
 
+HAS_FAST_TANH_INSTRUCTION = (
+    torch.version.cuda is not None and torch.cuda.get_device_capability()[0] >= 9
+)  # H100
+
+
 # Don't change the order of the enum values, as they are used to index
 # Only add new activation functions at the end of the enum
 class Activation(str, Enum):
@@ -50,17 +55,6 @@ def activation_string_to_int(s: str):
     return activation_to_int.get(enum_val)
 
 
-def is_hip_or_a100():
-    try:
-        if triton.runtime.driver.active.get_current_target().backend == "hip":
-            return True
-        elif torch.cuda.get_device_capability()[0] < 9:  # A100
-            return True
-        return False
-    except Exception:
-        return False
-
-
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -79,11 +73,11 @@ def gelu_grad(x):
     return cdf + x * pdf
 
 
-if is_hip_or_a100():
-    # For AMD or A100, use tanh as a fallback
+if not HAS_FAST_TANH_INSTRUCTION:
+
     @triton.jit
     def tanh_approx_fp32(x):
-        return tanh(x)
+        return 2 * fast_dividef(1.0, 1.0 + fast_expf(-2.0 * x)) - 1.0
 else:
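
Note (not part of the patch): a minimal NumPy sketch illustrating why the sigmoid
formulation is numerically safer in fp32 than dividing two exponentials. The function
names here are illustrative only; the patch itself uses Triton's fast_dividef/fast_expf
intrinsics rather than NumPy.

import numpy as np

np.seterr(all="ignore")  # silence the expected overflow/invalid warnings

# tanh written as a ratio of exponentials: (exp(2x) - 1) / (exp(2x) + 1).
# exp(2x) overflows float32 around x ~ 44, and inf / inf produces nan.
def tanh_exp_ratio(x):
    e = np.exp(np.float32(2.0) * x)
    return (e - np.float32(1.0)) / (e + np.float32(1.0))

# Sigmoid formulation used by the patch: tanh(x) = 2 / (1 + exp(-2x)) - 1.
# For large |x| the exponential either becomes negligible or overflows to inf,
# so the expression saturates cleanly to +/-1 instead of producing nan.
def tanh_sigmoid(x):
    return np.float32(2.0) / (np.float32(1.0) + np.exp(np.float32(-2.0) * x)) - np.float32(1.0)

x = np.float32(50.0)
print(tanh_exp_ratio(x))   # nan  (inf / inf)
print(tanh_sigmoid(x))     # 1.0  (exp(-100) is negligible next to 1)
print(tanh_sigmoid(-x))    # -1.0 (2 / inf - 1)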