@@ -82,6 +82,84 @@ struct GluonOpBuilder : public TritonOpBuilder {
   }
 };
 
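+// Caches py::handles to the Gluon layout classes so the Python module is
+// imported only once (see the function-local static in layoutToGluon).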
+struct GluonLayouts {
+  py::handle BlockedLayout;
+  py::handle SliceLayout;
+  py::handle DistributedLinearLayout;
+  py::handle NVMMASharedLayout;
+  py::handle SwizzledSharedLayout;
+
+  GluonLayouts() {
+    auto layouts =
+        py::module::import("triton.experimental.gluon.language._layouts");
+    BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
+    SliceLayout = py::object(layouts.attr("SliceLayout")).release();
+    DistributedLinearLayout =
+        py::object(layouts.attr("DistributedLinearLayout")).release();
+    NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
+    SwizzledSharedLayout =
+        py::object(layouts.attr("SwizzledSharedLayout")).release();
+  }
+};
+
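+// pybind11 has no caster for llvm::ArrayRef, so copy the elements into a
+// std::vector, which converts to a Python list automatically.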
+template <typename T> std::vector<T> toStdVector(llvm::ArrayRef<T> array) {
+  return std::vector<T>(array.begin(), array.end());
+}
+
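+// Convert an MLIR TritonGPU layout attribute into its Python-level Gluon
+// layout equivalent. Throws for encodings without a Gluon counterpart.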
+py::object layoutToGluon(Attribute layout) {
+  static GluonLayouts layouts;
+  if (auto blocked = dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
+    auto ctaLayout = blocked.getCTALayout();
+    return layouts.BlockedLayout(toStdVector(blocked.getSizePerThread()),
+                                 toStdVector(blocked.getThreadsPerWarp()),
+                                 toStdVector(blocked.getWarpsPerCTA()),
+                                 toStdVector(blocked.getOrder()),
+                                 toStdVector(ctaLayout.getCTAsPerCGA()),
+                                 toStdVector(ctaLayout.getCTASplitNum()),
+                                 toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto sliced = dyn_cast<ttg::SliceEncodingAttr>(layout)) {
+    return layouts.SliceLayout(sliced.getDim(),
+                               layoutToGluon(sliced.getParent()));
+  } else if (auto linear = dyn_cast<ttg::LinearEncodingAttr>(layout)) {
+    auto ll = linear.getLinearLayout();
+    auto ctx = layout.getContext();
+    auto kReg = mlir::StringAttr::get(ctx, "register");
+    auto kLane = mlir::StringAttr::get(ctx, "lane");
+    auto kWarp = mlir::StringAttr::get(ctx, "warp");
+    auto kBlock = mlir::StringAttr::get(ctx, "block");
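+    // Pass the basis vectors of each hardware dimension straight through to
+    // the Python DistributedLinearLayout constructor.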
+    return layouts.DistributedLinearLayout(
+        ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
+        ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
+        ll.getOutDimSizes());
+  } else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
+    auto ctaLayout = nvmma.getCTALayout();
+    return layouts.NVMMASharedLayout(
+        nvmma.getSwizzlingByteWidth(), nvmma.getElementBitWidth(),
+        ctaLayout.getRank(), nvmma.getTransposed(), nvmma.getFp4Padded(),
+        toStdVector(ctaLayout.getCTAsPerCGA()),
+        toStdVector(ctaLayout.getCTASplitNum()),
+        toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto swizzled =
+                 dyn_cast<ttg::SwizzledSharedEncodingAttr>(layout)) {
+    auto ctaLayout = swizzled.getCTALayout();
+    return layouts.SwizzledSharedLayout(
+        swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(),
+        toStdVector(swizzled.getOrder()),
+        toStdVector(ctaLayout.getCTAsPerCGA()),
+        toStdVector(ctaLayout.getCTASplitNum()),
+        toStdVector(ctaLayout.getCTAOrder()));
+  }
+  throw py::value_error("Unhandled encoding encountered");
+}
+
 void init_gluon_ir(py::module &&m) {
   using ret = py::return_value_policy;
@@ -189,6 +258,13 @@ void init_gluon_ir(py::module &&m) {
                    ctx, block[0], block[1], unpacked, ctaSplitNum[0],
                    ctaSplitNum[1]);
            })
+      .def("get_gluon_layout_from_tensor",
+           [](GluonOpBuilder &self, Value tensor) -> py::object {
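+             // Read the layout encoding off the tensor type and convert it.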
+             auto ty = dyn_cast<RankedTensorType>(tensor.getType());
+             assert(ty && ty.getEncoding());
+             return layoutToGluon(ty.getEncoding());
+           })
       .def("create_convert_layout",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
             return self.create<ttg::ConvertLayoutOp>(resultTy, value);