Skip to content

Commit e72adf2

Browse files
Move get_threads_per_warp to init_triton_intel (#2932)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 544b879 commit e72adf2

File tree

3 files changed

+7
-16
lines changed

3 files changed

+7
-16
lines changed

python/src/ir.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,6 @@ void outputWarning(Location loc, const std::string &msg) {
138138

139139
} // anonymous namespace
140140

141-
/*****************************************************************************/
142-
/* Python bindings for triton::ir::ttgir */
143-
/*****************************************************************************/
144-
145-
void init_triton_ttgpuir(py::module &&m) {
146-
m.def("get_threads_per_warp", [](mlir::ModuleOp &mod) -> py::object {
147-
auto ret = mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
148-
return py::int_(ret);
149-
});
150-
}
151-
152141
/*****************************************************************************/
153142
/* Python bindings for ir */
154143
/*****************************************************************************/
@@ -1794,9 +1783,6 @@ void init_triton_ir(py::module &&m) {
17941783
if (failed(self.run(mod.getOperation())))
17951784
throw std::runtime_error("PassManager::run failed");
17961785
});
1797-
1798-
// ttgpu dialect bindings.
1799-
init_triton_ttgpuir(m.def_submodule("ttgpuir"));
18001786
}
18011787

18021788
void init_triton_env_vars(py::module &m) {

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def make_ttgir(mod, metadata, opt, properties):
228228
pm.run(mod)
229229

230230
# Overwrite the threads_per_warp option with the module annotation.
231-
opt.threads_per_warp = ir.ttgpuir.get_threads_per_warp(mod)
231+
opt.threads_per_warp = intel.get_threads_per_warp(mod)
232232

233233
# Run the TTIR -> TTGIR pass pipeline.
234234
pm = ir.pass_manager(mod.context)
@@ -271,7 +271,7 @@ def make_llir(src, metadata, options):
271271
num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta")
272272
if num_warp_groups is not None:
273273
metadata["num_warps"] *= num_warp_groups
274-
threads_per_warp = ir.ttgpuir.get_threads_per_warp(src)
274+
threads_per_warp = intel.get_threads_per_warp(src)
275275
metadata["threads_per_warp"] = threads_per_warp
276276
mod = src
277277
# TritonGPU -> LLVM-IR (MLIR)

third_party/intel/triton_xpu.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ void init_triton_intel(py::module &&m) {
234234
context.loadAllAvailableDialects();
235235
});
236236

237+
m.def("get_threads_per_warp", [](mlir::ModuleOp &mod) -> py::object {
238+
auto ret = mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
239+
return py::int_(ret);
240+
});
241+
237242
// May do this after llvm ir according to user fmath flag.
238243
m.def("set_fast_math", [](mlir::ModuleOp mod) {
239244
using namespace mlir;

0 commit comments

Comments
 (0)