
Commit 84fe9da

Use fast math function for tl.math.log as exp (triton-lang#4723)
We were using the precise log op by mistake; to get high precision, users can call libdevice directly. Also clean up the special case for math.exp.
1 parent: df26ec6 · commit: 84fe9da
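The user-visible effect: tl.log and friends now compile to the approximate hardware instruction, and libdevice remains the precise path. A minimal sketch of both paths; the libdevice import path is an assumption here, since it has moved between Triton releases (e.g. tl.extra.cuda.libdevice on older ones):

import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice  # assumed path; varies by Triton version

@triton.jit
def log_kernel(X, Y_fast, Y_precise, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.store(Y_fast + offs, tl.log(x))            # fast path: lg2.approx on NVIDIA
    tl.store(Y_precise + offs, libdevice.log(x))  # precise path: __nv_logf

x = torch.rand(1024, device='cuda') + 1e-6
y_fast, y_precise = torch.empty_like(x), torch.empty_like(x)
log_kernel[(1, )](x, y_fast, y_precise, BLOCK=1024)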

File tree: 4 files changed (+19, -41 lines)


lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 2 deletions
@@ -346,8 +346,7 @@ struct ElementwiseOpConversion
                             ConversionPatternRewriter &rewriter,
                             Type elemTy, MultipleOperandsRange operands,
                             Location loc) const {
-    return {rewriter.create<DestOp>(loc, elemTy, operands[0],
-                                    adaptor.getAttributes().getValue())};
+    return {rewriter.create<DestOp>(loc, elemTy, operands[0], op->getAttrs())};
   }
 };

python/src/ir.cc

Lines changed: 12 additions & 7 deletions
@@ -1,4 +1,4 @@
-#include <pybind11/functional.h>
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

@@ -135,6 +135,11 @@ void outputWarning(Location loc, const std::string &msg) {
                 /*stack_level=*/2);
 }
 
+template <typename OpTy> OpTy approxMath(OpTy op) {
+  op.setFastmath(arith::FastMathFlags::afn);
+  return op;
+}
+
 } // anonymous namespace
 
 /*****************************************************************************/

@@ -1447,27 +1452,27 @@ void init_triton_ir(py::module &&m) {
       })
       .def("create_exp",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::ExpOp>(val);
+             return approxMath(self.create<math::ExpOp>(val));
            })
       .def("create_exp2",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::Exp2Op>(val);
+             return approxMath(self.create<math::Exp2Op>(val));
            })
       .def("create_cos",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::CosOp>(val);
+             return approxMath(self.create<math::CosOp>(val));
            })
       .def("create_sin",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::SinOp>(val);
+             return approxMath(self.create<math::SinOp>(val));
            })
       .def("create_log",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::LogOp>(val);
+             return approxMath(self.create<math::LogOp>(val));
            })
       .def("create_log2",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return self.create<math::Log2Op>(val);
+             return approxMath(self.create<math::Log2Op>(val));
            })
       .def("create_erf",
            [](TritonOpBuilder &self, Value &val) -> Value {
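With the approxMath wrapper above, every builder call tags the op with the afn fastmath flag, so the approximation decision travels in the IR instead of living in a dedicated lowering pattern. A hedged way to observe this from Python, assuming the flag prints as fastmath<afn> in the Triton IR dump (the exact textual form can vary with the MLIR version):

import torch
import triton
import triton.language as tl

@triton.jit
def exp_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(Y + offs, tl.exp(tl.load(X + offs)))

x = torch.rand(128, device='cuda')
y = torch.empty_like(x)
k = exp_kernel[(1, )](x, y, BLOCK=128)
# math.exp in the TTIR should now carry the flag set by approxMath()
assert 'fastmath<afn>' in k.asm['ttir']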

python/test/unit/language/test_core.py

Lines changed: 6 additions & 1 deletion
@@ -4376,9 +4376,14 @@ def kernel(X, Y, BLOCK: tl.constexpr):
     x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device))
     y = torch.zeros(shape, dtype=torch.float32, device=device)
 
-    kernel[(1, )](x, y, BLOCK=shape[0])
+    k = kernel[(1, )](x, y, BLOCK=shape[0])
     torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3)
 
+    if func_str in ['log', 'log2'] and is_cuda():
+        assert 'lg2.approx.ftz.f32' in k.asm['ptx']
+    if func_str in ['exp', 'exp2'] and is_cuda():
+        assert 'ex2.approx.ftz.f32' in k.asm['ptx']
+
 
 # -----------------------
 # test inline asm

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 0 additions & 31 deletions
@@ -755,32 +755,6 @@ struct TruncFOpConversion
   }
 };
 
-struct ExpOpConversionApprox
-    : ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox> {
-  using Base = ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  SmallVector<Value> createDestOps(math::ExpOp op, OpAdaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   Type elemTy, MultipleOperandsRange operands,
-                                   Location loc) const {
-    // For non-FP32 input, call __nv_expf for higher-precision calculation
-    if (elemTy.getIntOrFloatBitWidth() != 32)
-      return {};
-
-    const double log2e = 1.4426950408889634;
-    Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));
-
-    PTXBuilder ptxBuilder;
-    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
-    auto output = ptxBuilder.newOperand("=f");
-    auto input = ptxBuilder.newOperand(prod, "f");
-    exp2(output, input);
-    return {ptxBuilder.launch(rewriter, loc, f32_ty, false)};
-  }
-};
-
 struct ClampFOpConversion
     : ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion> {
   using Base = ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion>;

@@ -951,11 +925,6 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
   patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
                                    computeCapability, benefit);
 
-  // ExpOpConversionApprox will try using ex2.approx if the input type is
-  // FP32. For other input types, ExpOpConversionApprox will return failure and
-  // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
-  // __nv_expf for higher-precision calculation
-  patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
   bool hwNanPropagationSupported = computeCapability >= 80;
   mlir::triton::populateMinMaxFOpToLLVMPattern(
       typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported,
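The deleted pattern hand-rolled the rewrite exp(x) = 2^(x * log2(e)) in PTX; with the afn flag now set on math.exp itself, the remaining generic lowering is expected to pick the approximate path (the PTX assertions in the test above confirm ex2.approx still appears). A quick NumPy sketch of the identity the removed code implemented, with NumPy standing in for the ex2.approx instruction:

import numpy as np

log2e = np.float32(1.4426950408889634)  # constant from the removed ExpOpConversionApprox
x = np.linspace(-10, 10, 5, dtype=np.float32)
print(np.exp(x))           # reference
print(np.exp2(x * log2e))  # exp(x) as exp2(x * log2(e)); agrees to f32 rounding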
