diff --git a/test/TritonIntelGPU/schedule-load.mlir b/test/TritonIntelGPU/schedule-load.mlir
index 6352b6d033..6053ff4ed3 100644
--- a/test/TritonIntelGPU/schedule-load.mlir
+++ b/test/TritonIntelGPU/schedule-load.mlir
@@ -148,7 +148,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
       scf.yield %98, %106, %110, %114, %118, %122, %126, %130, %321, %322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, %336 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>
     }
     %cst_1 = arith.constant dense<1.000000e+00> : tensor<8x16xf32>
-    // CHECK-COUNT-8: arith.divf {{.*}} fastmath<fast>
     %67 = arith.divf %62#1, %cst_1 : tensor<8x16xf32>
     %68 = arith.divf %62#2, %cst_1 : tensor<8x16xf32>
     %69 = arith.divf %62#3, %cst_1 : tensor<8x16xf32>
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 32c8c2e133..f5afb391d3 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -261,6 +261,7 @@ def make_llir(src, metadata, options):
         if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
             intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         intel.passes.ttgpuir.add_to_llvmir(pm)
+        intel.set_fast_math(mod)
         passes.convert.add_arith_to_llvmir(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
index 41e975f9f9..b6ed435921 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
@@ -90,15 +90,6 @@ class ScheduleLoadPass
         op->moveAfter(def);
       }
     });
-
-    // HoHo, add fastmath for all
-    // may do this after llvm ir according to user fmath flag
-    mod.walk([&](Operation *op) {
-      if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
-        op->setAttr(
-            fmIf.getFastMathAttrName(),
-            arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
-    });
   }

 private:
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index eb7be1c080..951de6ce35 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -215,6 +215,18 @@ void init_triton_intel(py::module &&m) {
     context.loadAllAvailableDialects();
   });

+  // May do this after llvm ir according to user fmath flag.
+  m.def("set_fast_math", [](mlir::ModuleOp mod) {
+    using namespace mlir;
+    MLIRContext *ctx = mod.getContext();
+    mod.walk([&](Operation *op) {
+      if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
+        op->setAttr(
+            fmIf.getFastMathAttrName(),
+            arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
+    });
+  });
+
  m.def("set_spv_target_triple", [](llvm::Module *mod) {
    std::string triple = "spir64-unknown-unknown";
    std::string layout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"