66#include " mlir/IR/Types.h"
77#include " triton/Dialect/TritonGPU/IR/Attributes.h"
88#include " triton/Dialect/TritonGPU/IR/Dialect.h"
9+ #include " triton/Dialect/TritonGPU/IR/Types.h"
910#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1011
1112using namespace mlir ;
@@ -80,6 +81,17 @@ void init_gluon_ir(py::module &&m) {
8081 ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded,
8182 ctaLayout);
8283 })
84+ .def (" get_swizzled_shared_layout" ,
85+ [](GluonOpBuilder &self, int vec, int perPhase, int maxPhase,
86+ std::vector<unsigned > &order, std::vector<unsigned > &ctasPerCga,
87+ std::vector<unsigned > &ctaSplitNum,
88+ std::vector<unsigned > &ctaOrder) -> Attribute {
89+ auto ctx = self.getContext ();
90+ auto ctaLayout = ttg::CTALayoutAttr::get (ctx, ctasPerCga,
91+ ctaSplitNum, ctaOrder);
92+ return ttg::SwizzledSharedEncodingAttr::get (
93+ ctx, vec, perPhase, maxPhase, order, ctaLayout);
94+ })
8395 .def (" get_tensor_memory_layout" ,
8496 [](GluonOpBuilder &self, std::vector<unsigned > &block, bool unpacked,
8597 std::vector<unsigned > &ctaSplitNum) -> Attribute {
@@ -94,6 +106,10 @@ void init_gluon_ir(py::module &&m) {
94106 [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
95107 return self.create <ttg::ConvertLayoutOp>(resultTy, value);
96108 })
109+ .def (" create_local_alloc" ,
110+ [](GluonOpBuilder &self, Type resultTy) -> Value {
111+ return self.create <ttg::LocalAllocOp>(resultTy);
112+ })
97113 .def (" create_local_alloc" ,
98114 [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
99115 return self.create <ttg::LocalAllocOp>(resultTy, value);
@@ -106,10 +122,19 @@ void init_gluon_ir(py::module &&m) {
106122 [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
107123 return self.create <ttg::LocalLoadOp>(resultTy, memDesc);
108124 })
125+ .def (" create_local_dealloc" ,
126+ [](GluonOpBuilder &self, Value memDesc) -> Operation * {
127+ return self.create <ttg::LocalDeallocOp>(memDesc);
128+ })
129+
109130 .def (" create_tmem_alloc" ,
110131 [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
111132 return self.create <ttng::TMEMAllocOp>(resultTy, value);
112133 })
134+ .def (" create_tmem_alloc" ,
135+ [](GluonOpBuilder &self, Type resultTy, py::none value) -> Value {
136+ return self.create <ttng::TMEMAllocOp>(resultTy, Value{});
137+ })
113138 .def (" create_tmem_store" ,
114139 [](GluonOpBuilder &self, Value memDesc, Value value, Value pred) {
115140 self.create <ttng::TMEMStoreOp>(memDesc, value, pred);
@@ -123,6 +148,38 @@ void init_gluon_ir(py::module &&m) {
123148 int N) -> Value {
124149 return self.create <ttng::TMEMSubSliceOp>(resultTy, memDesc, N);
125150 })
151+ .def (" create_mbarrier_init" ,
152+ [](GluonOpBuilder &self, Value memDesc, int count) {
153+ self.create <ttng::InitBarrierOp>(memDesc, count);
154+ })
155+ .def (" create_mbarrier_inval" ,
156+ [](GluonOpBuilder &self, Value memDesc) {
157+ self.create <ttng::InvalBarrierOp>(memDesc);
158+ })
159+ .def (" create_mbarrier_expect" ,
160+ [](GluonOpBuilder &self, Value memDesc, int bytes, Value pred) {
161+ self.create <ttng::BarrierExpectOp>(memDesc, bytes, pred);
162+ })
163+ .def (" create_mbarrier_wait" ,
164+ [](GluonOpBuilder &self, Value memDesc, Value phase, Value pred,
165+ std::vector<Value> &deps) {
166+ self.create <ttng::WaitBarrierOp>(memDesc, phase, pred, deps);
167+ })
168+ .def (" create_mbarrier_arrive" ,
169+ [](GluonOpBuilder &self, Value memDesc, int count, Value pred) {
170+ self.create <ttng::ArriveBarrierOp>(memDesc, count, pred);
171+ })
172+ .def (" create_tcgen05_mma" ,
173+ [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
174+ Value pred, std::vector<Value> &mbarriers,
175+ std::vector<Value> &mbarrier_preds) {
176+ Value accDep;
177+ bool two_ctas = false ;
178+ auto tokType = self.getBuilder ().getType <ttg::AsyncTokenType>();
179+ self.create <ttng::TCGen5MMAOp>(tokType, a, b, acc, accDep, useAcc,
180+ pred, two_ctas, mbarriers,
181+ mbarrier_preds);
182+ })
126183 .def (" create_warp_return" ,
127184 [](GluonOpBuilder &self) -> Operation * {
128185 return self.create <ttg::WarpReturnOp>();
0 commit comments