@@ -149,6 +149,62 @@ inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
149149 return getTMAContigDim (memDescType.getEncoding (), memDescType.getShape ());
150150}
151151
152+ inline std::optional<int > getTMASwizzleMode (Operation *op, TensorDescType ty) {
153+ auto encoding = ty.getBlockType ().getEncoding ();
154+ auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
155+ unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth () : 0 ;
156+ if (!mmaEncoding) {
157+ auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding);
158+ if (!swizzledEnc || swizzledEnc.getVec () != 1 ||
159+ swizzledEnc.getPerPhase () != 1 || swizzledEnc.getMaxPhase () != 1 ) {
160+ if (op)
161+ op->emitError (" Unhandled encoding type" );
162+ return std::nullopt ;
163+ }
164+ }
165+
166+ bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded ();
167+ assert (!fp4Padded || swizzleBytes == 128 &&
168+ " elem type .b4x16_p64 supports only 128B swizzling" );
169+
170+ int32_t swizzleMode = 0 ;
171+ if (swizzleBytes == 128 ) {
172+ swizzleMode = 3 ;
173+ } else if (swizzleBytes == 64 ) {
174+ swizzleMode = 2 ;
175+ } else if (swizzleBytes == 32 ) {
176+ swizzleMode = 1 ;
177+ }
178+ return swizzleMode;
179+ }
180+
181+ inline std::optional<int > getTMAElementType (Operation *op, TensorDescType ty) {
182+ auto encoding = ty.getBlockType ().getEncoding ();
183+ auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
184+ bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded ();
185+
186+ if (fp4Padded)
187+ return 14 ; // .b4x16_p64
188+
189+ auto elemSize = ty.getBlockType ().getElementTypeBitWidth () / 8 ;
190+ switch (elemSize) {
191+ case 1 :
192+ return 0 ;
193+ case 2 :
194+ return 1 ;
195+ case 4 :
196+ return 2 ;
197+ default :
198+ break ;
199+ }
200+ if (op) {
201+ op->emitError ()
202+ << " Tensor descriptor element type must have size 1, 2, or 4 but got "
203+ << elemSize;
204+ }
205+ return std::nullopt ;
206+ }
207+
152208template <typename BuilderT>
153209mlir::LogicalResult createTMADesc (mlir::Value tmaPtr,
154210 mlir::triton::MakeTensorDescOp op,
@@ -182,8 +238,6 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
182238 boxDim.push_back (mkI32Constant (shapePerCTA[k]));
183239
184240 unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth () : 0 ;
185- assert (!fp4Padded || swizzleBytes == 128 &&
186- " elem type .b4x16_p64 supports only 128B swizzling" );
187241 if (!mmaEncoding) {
188242 auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(
189243 op.getType ().getBlockType ().getEncoding ());
@@ -194,14 +248,10 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
194248 }
195249 }
196250
197- int32_t swizzle_mode = 0 ;
198- if (swizzleBytes == 128 ) {
199- swizzle_mode = 3 ;
200- } else if (swizzleBytes == 64 ) {
201- swizzle_mode = 2 ;
202- } else if (swizzleBytes == 32 ) {
203- swizzle_mode = 1 ;
204- }
251+ auto maybeSwizzleMode = getTMASwizzleMode (op, op.getType ());
252+ if (!maybeSwizzleMode)
253+ return failure ();
254+ auto swizzleMode = *maybeSwizzleMode;
205255
206256 Value elemSizeVal = builder.template create <arith::ConstantOp>(
207257 loc, builder.getI64Type (), builder.getI64IntegerAttr (elemSize));
@@ -224,31 +274,9 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
224274 globalStride[i] = builder.template create <arith::MulIOp>(
225275 loc, globalStride[i], elemSizeVal);
226276
227- int elemTypeEnum;
228-
229- if (fp4Padded) {
230- elemTypeEnum = 14 ; // .b4x16_p64
231- } else {
232- switch (elemSize) {
233- case 1 : {
234- elemTypeEnum = 0 ;
235- break ;
236- }
237- case 2 : {
238- elemTypeEnum = 1 ;
239- break ;
240- }
241- case 4 : {
242- elemTypeEnum = 2 ;
243- break ;
244- }
245- default : {
246- op->emitError ()
247- << " Tensor descriptor element type must have size 1, 2, or 4 but got "
248- << elemSize;
249- return failure ();
250- }
251- }
277+ auto elemTypeEnum = getTMAElementType (op, op.getType ());
278+ if (!elemTypeEnum) {
279+ return failure ();
252280 }
253281
254282 builder.template create <triton::ExperimentalTensormapCreateOp>(
@@ -259,9 +287,9 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
259287 /* global_dim=*/ globalDim,
260288 /* global_stride=*/ globalStride,
261289 /* element_strides=*/ elementStride,
262- /* elem_type*/ builder.getI32IntegerAttr (elemTypeEnum),
290+ /* elem_type*/ builder.getI32IntegerAttr (* elemTypeEnum),
263291 /* interleave_layout*/ builder.getI32IntegerAttr (0 ),
264- /* swizzle_mode=*/ builder.getI32IntegerAttr (swizzle_mode ),
292+ /* swizzle_mode=*/ builder.getI32IntegerAttr (swizzleMode ),
265293 /* fill_mode=*/ builder.getI32IntegerAttr (0 ));
266294 return success ();
267295}
0 commit comments