Commit 831d42b

Set FastMathFlags by default
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 48e661d commit 831d42b

4 files changed: +12, -10 lines

test/TritonIntelGPU/schedule-load.mlir (0 additions, 1 deletion)

@@ -148,7 +148,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
       scf.yield %98, %106, %110, %114, %118, %122, %126, %130, %321, %322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, %336 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>
     }
     %cst_1 = arith.constant dense<1.000000e+00> : tensor<8x16xf32>
-    // CHECK-COUNT-8: arith.divf {{.*}} fastmath<fast>
     %67 = arith.divf %62#1, %cst_1 : tensor<8x16xf32>
     %68 = arith.divf %62#2, %cst_1 : tensor<8x16xf32>
     %69 = arith.divf %62#3, %cst_1 : tensor<8x16xf32>
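
The deleted CHECK line asserted that ScheduleLoad's own output already carried the flag on these divides, i.e. ops printed as arith.divf %62#1, %cst_1 fastmath<fast> : tensor<8x16xf32>. Since the pass no longer sets fast-math flags (see ScheduleLoad.cpp below), the TTGIR-level check is dropped; the flag is now applied through the new set_fast_math binding called from the compiler pipeline.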

third_party/intel/backend/compiler.py (1 addition, 0 deletions)

@@ -261,6 +261,7 @@ def make_llir(src, metadata, options):
     if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
     intel.passes.ttgpuir.add_to_llvmir(pm)
+    intel.set_fast_math(mod)
     passes.convert.add_arith_to_llvmir(pm)
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
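
intel.set_fast_math is the new binding added in third_party/intel/triton_xpu.cc (last hunk below). Note that, unlike the surrounding lines, it is a direct call on mod rather than a pass appended to pm, so it runs while make_llir is assembling the pipeline; the fastmath<fast> flags it stamps onto the arith ops can then be carried along by the arith-to-LLVM conversion added on the next line.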

third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp (0 additions, 9 deletions)

@@ -90,15 +90,6 @@ class ScheduleLoadPass
         op->moveAfter(def);
       }
     });
-
-    // HoHo, add fastmath for all
-    // may do this after llvm ir according to user fmath flag
-    mod.walk([&](Operation *op) {
-      if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
-        op->setAttr(
-            fmIf.getFastMathAttrName(),
-            arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
-    });
   }

 private:

third_party/intel/triton_xpu.cc (11 additions, 0 deletions)

@@ -215,6 +215,17 @@ void init_triton_intel(py::module &&m) {
     context.loadAllAvailableDialects();
   });

+  m.def("set_fast_math", [](mlir::ModuleOp mod) {
+    using namespace mlir;
+    MLIRContext *ctx = mod.getContext();
+    mod.walk([&](Operation *op) {
+      if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
+        op->setAttr(
+            fmIf.getFastMathAttrName(),
+            arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
+    });
+  });
+
   m.def("set_spv_target_triple", [](llvm::Module *mod) {
     std::string triple = "spir64-unknown-unknown";
     std::string layout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
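
For readers who want to experiment with this rewrite outside the Python binding, the same walk can be packaged as a standalone MLIR pass. The following is only a sketch against upstream MLIR; the pass name SetFastMathPass, its command-line argument, and anchoring it on ModuleOp are choices made for this illustration, not part of the commit:

// Sketch only: a standalone-pass equivalent of the set_fast_math binding
// above. Assumes a plain upstream-MLIR build; names are illustrative.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace {
struct SetFastMathPass
    : mlir::PassWrapper<SetFastMathPass, mlir::OperationPass<mlir::ModuleOp>> {
  llvm::StringRef getArgument() const final { return "set-fast-math"; }

  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    // Same logic as the pybind lambda: every op implementing the arith
    // fast-math interface gets its fast-math attribute forced to `fast`.
    getOperation().walk([&](mlir::Operation *op) {
      if (auto fmIf = llvm::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
        op->setAttr(fmIf.getFastMathAttrName(),
                    mlir::arith::FastMathFlagsAttr::get(
                        ctx, mlir::arith::FastMathFlags::fast));
    });
  }
};
} // namespace

Registered with mlir::PassRegistration<SetFastMathPass>(), this form could be driven from an opt-style tool; the commit instead exposes the logic directly to Python, which keeps the change in make_llir to a single call.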
