@@ -97,6 +97,7 @@ struct GluonLayouts {
9797 py::handle NVMMASharedLayout;
9898 py::handle SwizzledSharedLayout;
9999 py::handle AMDMFMALayout;
100+ py::handle AMDWMMALayout;
100101 py::handle PaddedSharedLayout;
101102 py::handle GluonDType;
102103
@@ -117,6 +118,7 @@ struct GluonLayouts {
117118 SwizzledSharedLayout =
118119 py::object (layouts.attr (" SwizzledSharedLayout" )).release ();
119120 AMDMFMALayout = py::object (amdLayouts.attr (" AMDMFMALayout" )).release ();
121+ AMDWMMALayout = py::object (amdLayouts.attr (" AMDWMMALayout" )).release ();
120122 PaddedSharedLayout =
121123 py::object (layouts.attr (" PaddedSharedLayout" )).release ();
122124
@@ -226,6 +228,14 @@ py::object layoutToGluon(Attribute layout) {
226228 toStdVector (ctaLayout.getCTAsPerCGA ()),
227229 toStdVector (ctaLayout.getCTASplitNum ()),
228230 toStdVector (ctaLayout.getCTAOrder ()));
231+ } else if (auto amdWmma = dyn_cast<ttg::AMDWmmaEncodingAttr>(layout)) {
232+ auto ctaLayout = amdWmma.getCTALayout ();
233+ return layouts.AMDWMMALayout (amdWmma.getVersion (),
234+ amdWmma.getIsTransposed (),
235+ toStdVector (amdWmma.getWarpsPerCTA ()),
236+ toStdVector (ctaLayout.getCTAsPerCGA ()),
237+ toStdVector (ctaLayout.getCTASplitNum ()),
238+ toStdVector (ctaLayout.getCTAOrder ()));
229239 } else if (auto paddedShared =
230240 dyn_cast<ttg::PaddedSharedEncodingAttr>(layout)) {
231241 auto *ctx = paddedShared.getContext ();
@@ -357,6 +367,18 @@ void init_gluon_ir(py::module &&m) {
357367 ctx, version, warpsPerCta, tilesPerWarp, instrShape[0 ],
358368 instrShape[1 ], transposed, ctaLayout, elemType);
359369 })
370+ .def (" get_amd_wmma_layout" ,
371+ [](GluonOpBuilder &self, unsigned version, bool transposed,
372+ std::vector<unsigned > &warpsPerCta,
373+ std::vector<unsigned > &ctasPerCga,
374+ std::vector<unsigned > &ctaSplitNum,
375+ std::vector<unsigned > &ctaOrder) -> Attribute {
376+ auto ctx = self.getContext ();
377+ auto ctaLayout = self.getChecked <ttg::CTALayoutAttr>(
378+ ctx, ctasPerCga, ctaSplitNum, ctaOrder);
379+ return ttg::AMDWmmaEncodingAttr::get (ctx, version, transposed,
380+ warpsPerCta, ctaLayout);
381+ })
360382 .def (" get_padded_shared_layout" ,
361383 [](GluonOpBuilder &self, std::vector<unsigned > &intervals,
362384 std::vector<unsigned > &paddings,
0 commit comments