Skip to content

Commit 27231bc

Browse files
authored
[MLIR][SPIRV] Lower SPIR-V Tan/Tanh ops to LLVM intrinsics (#168419)
Fixed #148354 Lower SPIR-V Tan/Tanh ops using the corresponding LLVM intrinsics to reduce instructions and prevent overflow caused by the previous `exp`-based expansion.
1 parent 2432465 commit 27231bc

File tree

2 files changed

+6
-32
lines changed

2 files changed

+6
-32
lines changed

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,20 +1520,12 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
15201520
if (!dstType)
15211521
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
15221522

1523-
Location loc = tanOp.getLoc();
1524-
Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
1525-
Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
1526-
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1523+
rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
1524+
adaptor.getOperands());
15271525
return success();
15281526
}
15291527
};
15301528

1531-
/// Convert `spirv.Tanh` to
1532-
///
1533-
/// exp(2x) - 1
1534-
/// -----------
1535-
/// exp(2x) + 1
1536-
///
15371529
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
15381530
public:
15391531
using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
@@ -1546,18 +1538,8 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
15461538
if (!dstType)
15471539
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
15481540

1549-
Location loc = tanhOp.getLoc();
1550-
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1551-
Value multiplied =
1552-
LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
1553-
Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
1554-
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1555-
Value numerator =
1556-
LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
1557-
Value denominator =
1558-
LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
1559-
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1560-
denominator);
1541+
rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
1542+
adaptor.getOperands());
15611543
return success();
15621544
}
15631545
};

mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" {
162162

163163
// CHECK-LABEL: @tan
164164
spirv.func @tan(%arg0: f32) "None" {
165-
// CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32
166-
// CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32
167-
// CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32
165+
// CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
168166
%0 = spirv.GL.Tan %arg0 : f32
169167
spirv.Return
170168
}
@@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" {
175173

176174
// CHECK-LABEL: @tanh
177175
spirv.func @tanh(%arg0: f32) "None" {
178-
// CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32
179-
// CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32
180-
// CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32
181-
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
182-
// CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
183-
// CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32
184-
// CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32
176+
// CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32
185177
%0 = spirv.GL.Tanh %arg0 : f32
186178
spirv.Return
187179
}

0 commit comments

Comments
 (0)