diff --git a/tritonbench/operators/gdpa/math.py b/tritonbench/operators/gdpa/math.py
index 7fedaf29..0151de7f 100644
--- a/tritonbench/operators/gdpa/math.py
+++ b/tritonbench/operators/gdpa/math.py
@@ -28,6 +28,11 @@
 from triton.language.math import fast_dividef, fast_expf


+HAS_FAST_TANH_INSTRUCTION = (
+    torch.version.cuda is not None and torch.cuda.get_device_capability()[0] >= 9
+)  # H100
+
+
 # Don't change the order of the enum values, as they are used to index
 # Only add new activation functions at the end of the enum
 class Activation(str, Enum):
@@ -50,17 +55,6 @@ def activation_string_to_int(s: str):
     return activation_to_int.get(enum_val)


-def is_hip_or_a100():
-    try:
-        if triton.runtime.driver.active.get_current_target().backend == "hip":
-            return True
-        elif torch.cuda.get_device_capability()[0] < 9:  # A100
-            return True
-        return False
-    except Exception:
-        return False
-
-
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -79,11 +73,11 @@ def gelu_grad(x):
     return cdf + x * pdf


-if is_hip_or_a100():
-    # For AMD or A100, use tanh as a fallback
+if not HAS_FAST_TANH_INSTRUCTION:
+
     @triton.jit
     def tanh_approx_fp32(x):
-        return tanh(x)
+        return 2 * fast_dividef(1.0, 1.0 + fast_expf(-2.0 * x)) - 1.0


 else:
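
Note on the fallback path: on devices without the fast hardware tanh instruction (HAS_FAST_TANH_INSTRUCTION gates on CUDA compute capability >= 9, i.e. H100 and newer), the new tanh_approx_fp32 builds tanh from the sigmoid identity tanh(x) = 2 * sigmoid(2x) - 1 = 2 / (1 + exp(-2x)) - 1, expressed with the fast_dividef and fast_expf intrinsics instead of delegating to the generic tanh helper as before. The following is a minimal, illustrative sketch (not part of the change) that checks the identity numerically in plain PyTorch:

    import torch

    # Check that 2 / (1 + exp(-2x)) - 1 matches tanh(x) in float32.
    x = torch.linspace(-6.0, 6.0, steps=1001, dtype=torch.float32)
    approx = 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
    assert torch.allclose(approx, torch.tanh(x), atol=1e-6)

The identity is exact in real arithmetic, so the fallback differs from tanh only by float32 rounding in the exp and divide.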