Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/TritonIntelGPU/schedule-load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
scf.yield %98, %106, %110, %114, %118, %122, %126, %130, %321, %322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, %336 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>, !tt.ptr<tensor<16x16xf16>>
}
%cst_1 = arith.constant dense<1.000000e+00> : tensor<8x16xf32>
// CHECK-COUNT-8: arith.divf {{.*}} fastmath<fast>
%67 = arith.divf %62#1, %cst_1 : tensor<8x16xf32>
%68 = arith.divf %62#2, %cst_1 : tensor<8x16xf32>
%69 = arith.divf %62#3, %cst_1 : tensor<8x16xf32>
Expand Down
1 change: 1 addition & 0 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def make_llir(src, metadata, options):
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
intel.passes.ttgpuir.add_to_llvmir(pm)
intel.set_fast_math(mod)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ class ScheduleLoadPass
op->moveAfter(def);
}
});

// HoHo, add fastmath for all
// may do this after llvm ir according to user fmath flag
mod.walk([&](Operation *op) {
if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
op->setAttr(
fmIf.getFastMathAttrName(),
arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
});
}

private:
Expand Down
11 changes: 11 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ void init_triton_intel(py::module &&m) {
context.loadAllAvailableDialects();
});

// Python binding: walk the whole module and force the `fast` fastmath
// flag on every operation implementing arith::ArithFastMathInterface
// (divf, addf, mulf, ...). This replaces the equivalent walk that was
// previously hard-coded inside ScheduleLoadPass (removed in this diff),
// so the backend (compiler.py: intel.set_fast_math(mod)) can apply it
// explicitly after add_to_llvmir.
// NOTE(review): this unconditionally overwrites any existing fastmath
// attribute on matching ops — presumably intended, but worth confirming
// against ops the user tagged with stricter flags.
m.def("set_fast_math", [](mlir::ModuleOp mod) {
  using namespace mlir;
  MLIRContext *ctx = mod.getContext();
  mod.walk([&](Operation *op) {
    // Interface cast filters to ops that carry a fastmath attribute.
    if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
      op->setAttr(
          fmIf.getFastMathAttrName(),
          arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
  });
});

m.def("set_spv_target_triple", [](llvm::Module *mod) {
std::string triple = "spir64-unknown-unknown";
std::string layout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
Expand Down