Skip to content

Conversation

@hankluo6
Copy link
Contributor

Fixed #148354

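Lower the SPIR-V `GL.Tan` and `GL.Tanh` ops directly to the corresponding LLVM intrinsics (`llvm.intr.tan` and `llvm.intr.tanh`). This emits fewer instructions than the previous sin/cos-division and exp-based expansions, and it avoids the overflow that the exp-based `Tanh` expansion could produce.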

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Hank (hankluo6)

Changes

Fixed #148354

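Lower the SPIR-V `GL.Tan` and `GL.Tanh` ops directly to the corresponding LLVM intrinsics (`llvm.intr.tan` and `llvm.intr.tanh`). This emits fewer instructions than the previous sin/cos-division and exp-based expansions, and it avoids the overflow that the exp-based `Tanh` expansion could produce.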


Full diff: https://github.com/llvm/llvm-project/pull/168419.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+4-22)
  • (modified) mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir (+2-10)
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50fca564b5b64..02b61bd989368 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1520,20 +1520,12 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
 
-    Location loc = tanOp.getLoc();
-    Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
-    Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
-    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
+    rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
+                                             adaptor.getOperands());
     return success();
   }
 };
 
-/// Convert `spirv.Tanh` to
-///
-///   exp(2x) - 1
-///   -----------
-///   exp(2x) + 1
-///
 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
 public:
   using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
@@ -1546,18 +1538,8 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
 
-    Location loc = tanhOp.getLoc();
-    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
-    Value multiplied =
-        LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
-    Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
-    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
-    Value numerator =
-        LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
-    Value denominator =
-        LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
-    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
-                                              denominator);
+    rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
+                                              adaptor.getOperands());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
index e1936e2fd8abe..b17e1c40cb9a7 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
@@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" {
 
 // CHECK-LABEL: @tan
 spirv.func @tan(%arg0: f32) "None" {
-  // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32
-  // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32
-  // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32
+  // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
   %0 = spirv.GL.Tan %arg0 : f32
   spirv.Return
 }
@@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" {
 
 // CHECK-LABEL: @tanh
 spirv.func @tanh(%arg0: f32) "None" {
-  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32
-  // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32
-  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32
-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
-  // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
-  // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32
-  // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32
+  // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32
   %0 = spirv.GL.Tanh %arg0 : f32
   spirv.Return
 }

@github-actions
Copy link

🐧 Linux x64 Test Results

  • 7079 tests passed
  • 594 tests skipped

@hankluo6
Copy link
Contributor Author

@kuhar could you please help merge this? Thank you.

@kuhar kuhar merged commit 27231bc into llvm:main Nov 18, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR] Inconsistent output when executing MLIR program with and without -convert-math-to-spirv

3 participants