Skip to content

Commit 4ff1fd6

Browse files
authored
[IR] Add convenience builder function for program id ops� (#4855)
1 parent 518b26e commit 4ff1fd6

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
602602

603603
let assemblyFormat = "$axis attr-dict `:` type($result)";
604604

605+
let builders = [
606+
OpBuilder<(ins "int":$axis), [{
607+
build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis)));
608+
}]>
609+
];
610+
605611
let extraClassDeclaration = [{
606612
int32_t getAxisAsInt() {
607613
return static_cast<int32_t>(getAxis());
@@ -615,6 +621,11 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
615621
let results = (outs I32:$result);
616622

617623
let assemblyFormat = "$axis attr-dict `:` type($result)";
624+
let builders = [
625+
OpBuilder<(ins "int":$axis), [{
626+
build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis)));
627+
}]>
628+
];
618629

619630
let extraClassDeclaration = [{
620631
int32_t getAxisAsInt() {

python/src/ir.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,19 +1417,13 @@ void init_triton_ir(py::module &&m) {
14171417
[](TritonOpBuilder &self, int axis) -> Value {
14181418
if (axis < 0 || axis > 3)
14191419
throw pybind11::index_error("program_id must be in [0,3]");
1420-
return self.create<GetProgramIdOp>(
1421-
self.getBuilder().getI32Type(),
1422-
ProgramIDDimAttr::get(self.getBuilder().getContext(),
1423-
ProgramIDDim(axis)));
1420+
return self.create<GetProgramIdOp>(axis);
14241421
})
14251422
.def("create_get_num_programs",
14261423
[](TritonOpBuilder &self, int axis) -> Value {
14271424
if (axis < 0 || axis > 3)
14281425
throw pybind11::index_error("program_id must be in [0,3]");
1429-
return self.create<GetNumProgramsOp>(
1430-
self.getBuilder().getI32Type(),
1431-
ProgramIDDimAttr::get(self.getBuilder().getContext(),
1432-
ProgramIDDim(axis)));
1426+
return self.create<GetNumProgramsOp>(axis);
14331427
})
14341428
.def("create_dot",
14351429
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,

0 commit comments

Comments
 (0)