  // Cached py::handle references to the Python-side Gluon layout classes,
  // looked up once from the `layouts` module (see the assignments below).
  // They are stored as raw handles obtained via .release(), i.e. without
  // ref-count management here — NOTE(review): presumably leaked on purpose
  // so they stay valid for the process lifetime; confirm.
  py::handle BlockedLayout;
  py::handle SliceLayout;
  py::handle DistributedLinearLayout;
  // New: NVIDIA MMA distributed layout, mirrors ttg::NvidiaMmaEncodingAttr.
  py::handle NVMMADistributedLayout;
  py::handle NVMMASharedLayout;
  py::handle SwizzledSharedLayout;
    // Resolve each Python layout class from the `layouts` module and cache it.
    // .release() drops ownership from the temporary py::object so the cached
    // handle keeps the class object alive — TODO(review): confirm the
    // intentional-leak rationale.
    SliceLayout = py::object(layouts.attr("SliceLayout")).release();
    DistributedLinearLayout =
        py::object(layouts.attr("DistributedLinearLayout")).release();
    // New: fetch NVMMADistributedLayout alongside the existing classes.
    NVMMADistributedLayout =
        py::object(layouts.attr("NVMMADistributedLayout")).release();
    NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
    SwizzledSharedLayout =
        py::object(layouts.attr("SwizzledSharedLayout")).release();
        // (tail of the DistributedLinearLayout branch: bases per hardware
        // dimension, then the output dim sizes)
        ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
        ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
        toStdVector(ArrayRef(llvm::to_vector(ll.getOutDimSizes()))));
  } else if (auto mma = dyn_cast<ttg::NvidiaMmaEncodingAttr>(layout)) {
    // New: translate an MLIR NvidiaMmaEncodingAttr into the Python-side
    // NVMMADistributedLayout class, field by field.
    auto ctaLayout = mma.getCTALayout();
    return layouts.NVMMADistributedLayout(
        // version is passed as {major, minor}
        std::vector<unsigned>{mma.getVersionMajor(), mma.getVersionMinor()},
        toStdVector(mma.getWarpsPerCTA()),
        toStdVector(ctaLayout.getCTAsPerCGA()),
        toStdVector(ctaLayout.getCTASplitNum()),
        toStdVector(ctaLayout.getCTAOrder()), toStdVector(mma.getInstrShape()));
  } else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
    auto ctaLayout = nvmma.getCTALayout();
    return layouts.NVMMASharedLayout(
@@ -224,6 +235,20 @@ void init_gluon_ir(py::module &&m) {
224235 /* requiresSurjective=*/ true );
225236 return ttg::LinearEncodingAttr::get (ctx, ll);
226237 })
238+ .def (" get_mma_layout" ,
239+ [](GluonOpBuilder &self, std::vector<unsigned > &version,
240+ std::vector<unsigned > &warpsPerCta,
241+ std::vector<unsigned > &ctasPerCga,
242+ std::vector<unsigned > &ctaSplitNum,
243+ std::vector<unsigned > &ctaOrder,
244+ std::vector<unsigned > &instrShape) -> Attribute {
245+ auto ctx = self.getContext ();
246+ auto ctaLayout = self.getChecked <ttg::CTALayoutAttr>(
247+ ctx, ctasPerCga, ctaSplitNum, ctaOrder);
248+ return self.getChecked <ttg::NvidiaMmaEncodingAttr>(
249+ ctx, version[0 ], version[1 ], warpsPerCta, ctaLayout,
250+ instrShape);
251+ })
227252 .def (" get_nvmma_shared_layout" ,
228253 [](GluonOpBuilder &self, unsigned swizzleByteWidth,
229254 unsigned elementBitwidth, bool transposed, bool fp4Padded,
@@ -359,6 +384,14 @@ void init_gluon_ir(py::module &&m) {
359384 auto op = self.create <triton::SplitOp>(TypeRange{resTy, resTy}, a);
360385 return py::make_tuple (op->getResult (0 ), op->getResult (1 ));
361386 })
387+ .def (" create_warpgroup_mma" ,
388+ [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
389+ triton::InputPrecision precision = triton::InputPrecision::IEEE,
390+ int maxNumImpreciseAcc = 0 , bool isAsync = false ) -> Value {
391+ return self.create <ttng::WarpGroupDotOp>(
392+ a, b, acc, useAcc, precision, maxNumImpreciseAcc, isAsync);
393+ })
394+
362395 .def (" create_tmem_alloc" ,
363396 [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
364397 return self.create <ttng::TMEMAllocOp>(resultTy, value);
0 commit comments