@@ -20,9 +20,9 @@ namespace mlir::triton::gpu {
 
 void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
                                               ShortcutFn shortcutFn) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
-  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
-  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+  MLIRContext *ctx = module.getContext();
+  int numCTAs = TritonGPUDialect::getNumCTAs(module);
+  int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
 
   module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
     OpBuilder builder(cvtOp);
@@ -31,28 +31,32 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
     auto srcMma = dyn_cast<MmaEncodingTrait>(srcType.getEncoding());
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
-    if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) {
-      auto tmpType = RankedTensorType::get(
-          dstType.getShape(), dstType.getElementType(),
-          triton::gpu::BlockedEncodingAttr::get(
-              module.getContext(), srcType.getShape(), getSizePerThread(srcMma),
-              getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
-      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
-          cvtOp.getLoc(), tmpType, cvtOp.getSrc());
-      addAttrs(tmp, cvtOp->getAttrs());
-      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-          cvtOp.getLoc(), dstType, tmp);
-      addAttrs(newConvert, cvtOp->getAttrs());
-      cvtOp.replaceAllUsesWith(newConvert.getResult());
-      cvtOp.erase();
-    }
+    if (!srcMma || !dstDotOp || shortcutFn(srcType, dstType))
+      return;
+
+    int numWarps = lookupNumWarps(cvtOp);
+    auto enc = BlockedEncodingAttr::get(
+        ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcMma),
+        numWarps, threadsPerWarp, numCTAs);
+    auto tmpType = RankedTensorType::get(dstType.getShape(),
+                                         dstType.getElementType(), enc);
+
+    auto tmp = builder.create<ConvertLayoutOp>(cvtOp.getLoc(), tmpType,
+                                               cvtOp.getSrc());
+    addAttrs(tmp, cvtOp->getAttrs());
+    auto newConvert =
+        builder.create<ConvertLayoutOp>(cvtOp.getLoc(), dstType, tmp);
+    addAttrs(newConvert, cvtOp->getAttrs());
+
+    cvtOp.replaceAllUsesWith(newConvert.getResult());
+    cvtOp.erase();
   });
 }
 
 void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
   int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
   int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+
   module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
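
For context, a minimal sketch (not part of this change) of how the two decomposition helpers might be driven from a backend pass. It assumes ShortcutFn is a callable of the form bool(RankedTensorType, RankedTensorType), consistent with the shortcutFn(srcType, dstType) call above; the function name and the placeholder predicate below are hypothetical, and the predicate simply decomposes every mma -> dot_operand conversion.

// Hypothetical driver; assumes the header declaring the decompose* helpers
// and ShortcutFn is included.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir::triton::gpu {

void runLayoutConversionDecomposition(ModuleOp module) {
  // Placeholder predicate: a real backend would return true for
  // mma -> dot_operand conversions it can lower directly, so those are left
  // intact instead of being split through an intermediate blocked layout.
  auto shortcutFn = [](RankedTensorType /*srcTy*/, RankedTensorType /*dstTy*/) {
    return false;
  };
  decomposeTensorCoreToDotLayoutConversion(module, shortcutFn);
  decomposeBlockedToDotLayoutConversion(module);
}

} // namespace mlir::triton::gpu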