@@ -59,6 +59,59 @@ bool isOneOperandElementwiseOp(Operation *op) {
5959 return false ;
6060}
6161
62+ static triton::StoreOp convertMfmaLayoutForCDNA4 (PatternRewriter &rewriter,
63+ Value ptr, Value val,
64+ Value mask,
65+ triton::StoreOp oldStOp) {
66+ auto ptrType = cast<RankedTensorType>(ptr.getType ());
67+ auto valType = cast<RankedTensorType>(val.getType ());
68+
69+ auto mfmaLayout =
70+ cast<triton::gpu::AMDMfmaEncodingAttr>(valType.getEncoding ());
71+
72+ bool mfma32 = mfmaLayout.getMDim () == 32 && mfmaLayout.getNDim () == 32 ;
73+
74+ if (valType.getRank () != 2 ||
75+ (!valType.getElementType ().isF16 () &&
76+ !valType.getElementType ().isBF16 ()) ||
77+ mfmaLayout.getVersionMajor () != 4 || !mfmaLayout.getIsTransposed () ||
78+ !mfma32) {
79+ return rewriter.create <triton::StoreOp>(oldStOp.getLoc (), ptr, val, mask,
80+ oldStOp.getCache (),
81+ oldStOp.getEvict ());
82+ }
83+
84+ // Create a new layout where each thread holds 8 consecutive elements, in
85+ // order to enable wide 128-bit global stores.
86+ triton::LinearLayout mfma8Layout =
87+ chooseMfmaLikeStoreLayout (mfmaLayout, valType.getShape ());
88+
89+ Attribute newEncoding = triton::gpu::LinearEncodingAttr::get (
90+ mfmaLayout.getContext (), mfma8Layout);
91+ auto newPtrType = RankedTensorType::get (
92+ ptrType.getShape (), ptrType.getElementType (), newEncoding);
93+ Value newPtr = rewriter.create <triton::gpu::ConvertLayoutOp>(ptr.getLoc (),
94+ newPtrType, ptr);
95+
96+ auto newValType = RankedTensorType::get (
97+ valType.getShape (), valType.getElementType (), newEncoding);
98+ Value newVal = rewriter.create <triton::gpu::ConvertLayoutOp>(val.getLoc (),
99+ newValType, val);
100+
101+ Value newMask = mask;
102+ if (mask) {
103+ auto maskType = dyn_cast<RankedTensorType>(mask.getType ());
104+ auto newMaskType = RankedTensorType::get (
105+ maskType.getShape (), maskType.getElementType (), newEncoding);
106+ newMask = rewriter.create <triton::gpu::ConvertLayoutOp>(mask.getLoc (),
107+ newMaskType, mask);
108+ }
109+
110+ return rewriter.create <triton::StoreOp>(oldStOp.getLoc (), newPtr, newVal,
111+ newMask, oldStOp.getCache (),
112+ oldStOp.getEvict ());
113+ }
114+
62115// convert(val) : xmma -> blocked
63116// elementWiseOp(val) : blocked
64117// ...
@@ -126,19 +179,20 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
126179 auto newEncoding =
127180 cast<RankedTensorType>(cvtOp.getSrc ().getType ()).getEncoding ();
128181
129- auto newVal = cvtOp.getSrc ();
130-
131182 auto newPtrType = RankedTensorType::get (
132183 ptrType.getShape (), ptrType.getElementType (), newEncoding);
133184 Value newPtr = rewriter.create <triton::gpu::ConvertLayoutOp>(
134185 ptr.getLoc (), newPtrType, ptr);
135186
187+ auto newVal = cvtOp.getSrc ();
188+
136189 for (auto chainedOp : llvm::reverse (chainedOps)) {
137190 auto oldType =
138191 cast<mlir::RankedTensorType>(chainedOp->getResult (0 ).getType ());
139192 chainedOp->setOperand (0 , newVal);
140193 newVal = llvm::cast<mlir::TypedValue<RankedTensorType>>(
141194 chainedOp->getResult (0 ));
195+
142196 auto newType = mlir::RankedTensorType::get (
143197 oldType.getShape (), oldType.getElementType (), newEncoding);
144198 newVal.setType (newType);
@@ -152,9 +206,18 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
152206 newMask = rewriter.create <triton::gpu::ConvertLayoutOp>(
153207 mask.getLoc (), newMaskType, mask);
154208 }
209+ triton::StoreOp newStoreOp;
210+ if (auto mfmaLayout =
211+ dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(newEncoding)) {
212+ newStoreOp =
213+ convertMfmaLayoutForCDNA4 (rewriter, newPtr, newVal, newMask, stOp);
214+ } else {
215+ newStoreOp = rewriter.create <triton::StoreOp>(
216+ stOp.getLoc (), newPtr, newVal, newMask, stOp.getCache (),
217+ stOp.getEvict ());
218+ }
155219
156- rewriter.replaceOpWithNewOp <triton::StoreOp>(
157- stOp, newPtr, newVal, newMask, stOp.getCache (), stOp.getEvict ());
220+ rewriter.replaceOp (stOp, newStoreOp);
158221 return mlir::success ();
159222 }
160223};
0 commit comments