Skip to content

Commit 3f4acb1

Browse files
Set FastMathFlags by default
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8bb917f commit 3f4acb1

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -261,6 +261,7 @@ def make_llir(src, metadata, options):
261261
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
262262
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
263263
intel.passes.ttgpuir.add_to_llvmir(pm)
264+
intel.set_fast_math(mod)
264265
passes.convert.add_arith_to_llvmir(pm)
265266
passes.common.add_canonicalizer(pm)
266267
passes.common.add_cse(pm)

third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -90,15 +90,6 @@ class ScheduleLoadPass
9090
op->moveAfter(def);
9191
}
9292
});
93-
94-
// HoHo, add fastmath for all
95-
// may do this after llvm ir according to user fmath flag
96-
mod.walk([&](Operation *op) {
97-
if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
98-
op->setAttr(
99-
fmIf.getFastMathAttrName(),
100-
arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
101-
});
10293
}
10394

10495
private:

third_party/intel/triton_xpu.cc

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,17 @@ void init_triton_intel(py::module &&m) {
215215
context.loadAllAvailableDialects();
216216
});
217217

218+
m.def("set_fast_math", [](mlir::ModuleOp mod) {
219+
using namespace mlir;
220+
MLIRContext *ctx = mod.getContext();
221+
mod.walk([&](Operation *op) {
222+
if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
223+
op->setAttr(
224+
fmIf.getFastMathAttrName(),
225+
arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
226+
});
227+
});
228+
218229
m.def("set_spv_target_triple", [](llvm::Module *mod) {
219230
std::string triple = "spir64-unknown-unknown";
220231
std::string layout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"

0 commit comments

Comments
 (0)