 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/Passes.h"
-#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/PriorityWorklist.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/VersionTuple.h"
+#include <algorithm>
 #include <memory>
 #include <unordered_set>
 
@@ -35,6 +29,7 @@ struct UseInfo {
   TypedValue<tt::TensorDescType> descriptor;
   Operation *use;
   Attribute desiredSharedEncoding;
+  SmallVector<int64_t> shape;
   ttg::CTALayoutAttr ctaLayout;
 };
 
@@ -72,6 +67,14 @@ ttg::CTALayoutAttr getCtaLayoutFromEncoding(Attribute encoding) {
                                  layout.getCTASplitNum(), layout.getCTAOrder());
 }
 
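+// Left-pads `shape` with leading 1s so it matches the descriptor block rank,
+// e.g. expandToRank({64, 32}, 4) == {1, 1, 64, 32}.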
+SmallVector<int64_t> expandToRank(ArrayRef<int64_t> shape, int rank) {
+  SmallVector<int64_t> result(rank, 1);
+  assert(shape.size() <= rank);
+  auto rankDiff = rank - shape.size();
+  std::copy(shape.begin(), shape.end(), result.begin() + rankDiff);
+  return result;
+}
+
 std::optional<UseInfo> getUseInfo(Operation *op) {
   UseInfo info;
   info.use = op;
@@ -81,6 +84,9 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : load.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
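+    // Record the tensor shape actually accessed, padded to the descriptor's
+    // block rank; it feeds the fallback shared-encoding choice below.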
+    auto shape = load.getResult().getType().getShape();
+    auto rank = load.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto gather = dyn_cast<tt::DescriptorGatherOp>(op)) {
@@ -89,18 +95,27 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : gather.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = gather.getResult().getType().getShape();
+    auto rank = gather.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto store = dyn_cast<tt::DescriptorStoreOp>(op)) {
     info.descriptor = store.getDesc();
     auto encoding = store.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = store.getSrc().getType().getShape();
+    auto rank = store.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto scatter = dyn_cast<tt::DescriptorScatterOp>(op)) {
     info.descriptor = scatter.getDesc();
     auto encoding = scatter.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = scatter.getSrc().getType().getShape();
+    auto rank = scatter.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   return std::nullopt;
@@ -109,12 +124,15 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
 struct EncodingInfo {
   Attribute desiredEncoding;
   ttg::CTALayoutAttr ctaLayout;
+  // Shape may be different from the descriptor block shape for the
+  // gather/scatter use case.
+  SmallVector<int64_t> shape;
   bool forcedToDefault = false;
 
   bool operator==(const EncodingInfo &other) const {
     return desiredEncoding == other.desiredEncoding &&
            ctaLayout == other.ctaLayout &&
-           forcedToDefault == other.forcedToDefault;
+           forcedToDefault == other.forcedToDefault && shape == other.shape;
   }
 };
 
@@ -123,7 +141,8 @@ struct EncodingInfo {
 template <> struct std::hash<EncodingInfo> {
   size_t operator()(const EncodingInfo &einfo) const {
     return llvm::hash_combine(einfo.desiredEncoding, einfo.ctaLayout,
-                              einfo.forcedToDefault);
+                              einfo.forcedToDefault,
+                              ArrayRef<int64_t>(einfo.shape));
   }
 };
 
@@ -172,6 +191,21 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
   // Always propagate forcedToDefault
   result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault;
 
+  if (result.forcedToDefault)
+    return result;
+
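+  // Merge the usage shapes: keep whichever side is known; when both are known,
+  // take the element-wise minimum.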
+  if (lhs.shape.empty() || lhs.shape == rhs.shape)
+    result.shape = rhs.shape;
+  else if (rhs.shape.empty())
+    result.shape = lhs.shape;
+  else {
+    assert(lhs.shape.size() == rhs.shape.size());
+    auto rank = lhs.shape.size();
+    result.shape.reserve(rank);
+    for (int i = 0; i < rank; ++i)
+      result.shape.push_back(std::min(lhs.shape[i], rhs.shape[i]));
+  }
+
   SetVector<ttg::CTALayoutAttr> ctaLayouts;
   if (lhs.ctaLayout)
     ctaLayouts.insert(lhs.ctaLayout);
@@ -190,9 +224,6 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
     break;
   }
 
-  if (result.forcedToDefault)
-    return result;
-
   SetVector<Attribute> desiredEncodings;
   if (lhs.desiredEncoding)
     desiredEncodings.insert(lhs.desiredEncoding);
@@ -213,23 +244,32 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
 }
 
 Attribute getFallbackSharedEncoding(RankedTensorType tensorType,
-                                    ttg::CTALayoutAttr ctaLayout) {
+                                    ttg::CTALayoutAttr ctaLayout,
+                                    ArrayRef<int64_t> usageShape) {
   auto ctx = tensorType.getContext();
   SmallVector<unsigned> order;
   for (int i = tensorType.getRank() - 1; i >= 0; --i)
     order.push_back(i);
 
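+  // Prefer the shape recorded at the use sites over the descriptor block
+  // shape when one is available.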
+  ArrayRef<int64_t> shape =
+      usageShape.empty() ? tensorType.getShape() : usageShape;
   if (!ctaLayout)
     ctaLayout = ttg::CTALayoutAttr::getDefault(ctx, tensorType.getRank());
   else if (ctaLayout.getRank() != tensorType.getRank())
-    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, tensorType.getShape());
+    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, shape);
+
+  auto elemTy = tensorType.getElementType();
+  auto shapePerCTA = ttg::getShapePerCTA(ctaLayout.getCTASplitNum(), shape);
+  unsigned eleBitWidth = tensorType.getElementType().getIntOrFloatBitWidth();
 
-  if (tensorType.getRank() == 1) {
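+  // Fall back to a plain (unswizzled) shared layout when the tile is 1-D, its
+  // contiguous dimension covers fewer than 32 bytes per CTA, or its
+  // second-to-last dimension is smaller than 8; otherwise pick an NVMMA shared
+  // encoding sized to the usage shape.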
+  auto contigDimSizeInBytes = shapePerCTA.back() * eleBitWidth / 8;
+  auto rank = tensorType.getRank();
+  if (rank == 1 || contigDimSizeInBytes < 32 || shape[rank - 2] < 8) {
     return ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
   }
-  return ttg::NVMMASharedEncodingAttr::get(
-      ctx, tensorType.getShape(), order, ctaLayout, tensorType.getElementType(),
-      /*fp4Padded*/ false);
+  return ttg::NVMMASharedEncodingAttr::get(ctx, shape, order, ctaLayout,
+                                           tensorType.getElementType(),
+                                           /*fp4Padded*/ false);
 }
 
 tt::TensorDescType getTensorDescTypeWithEncoding(Operation *op,
@@ -274,17 +314,19 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   // fallback to default encoding
   for (auto blockArg : func.getBlocks().front().getArguments())
     if (auto desc = dyn_cast<TypedValue<tt::TensorDescType>>(blockArg))
-      updateEncoding({desc}, EncodingInfo{{}, {}, /*forcedToDefault=*/true});
+      updateEncoding({desc},
+                     EncodingInfo{{}, {}, {}, /*forcedToDefault=*/true});
 
   func.walk([&](Operation *op) {
     if (auto info = getUseInfo(op)) {
-      updateEncoding(info->descriptor, EncodingInfo{info->desiredSharedEncoding,
-                                                    info->ctaLayout});
+      updateEncoding(info->descriptor,
+                     EncodingInfo{info->desiredSharedEncoding, info->ctaLayout,
+                                  info->shape});
     } else {
       bool forcedToDefault =
           isa<tt::CallOp, tt::ReturnOp, tt::ReinterpretTensorDescOp>(op);
       auto einfo =
-          internEncoding(encodings, EncodingInfo{{}, {}, forcedToDefault});
+          internEncoding(encodings, EncodingInfo{{}, {}, {}, forcedToDefault});
 
       auto setEncoding = [&](Value v) {
         auto typedVal = cast<TypedValue<tt::TensorDescType>>(v);
@@ -344,9 +386,10 @@ void assignMemoryLayouts(tt::FuncOp &func) {
     if (einfo->desiredEncoding) {
       newEncoding = einfo->desiredEncoding;
     } else if (einfo->forcedToDefault) {
-      newEncoding = getFallbackSharedEncoding(existingTy, {});
+      newEncoding = getFallbackSharedEncoding(existingTy, {}, {});
     } else {
-      newEncoding = getFallbackSharedEncoding(existingTy, einfo->ctaLayout);
+      newEncoding =
+          getFallbackSharedEncoding(existingTy, einfo->ctaLayout, einfo->shape);
     }
     desc.setType(getTensorDescTypeWithEncoding(desc.getDefiningOp(), existingTy,
                                                newEncoding));
@@ -356,14 +399,14 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   SmallVector<Type> resultTys(func.getResultTypes());
   for (auto [i, argTy] : llvm::enumerate(argTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(argTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       argTys[i] = getTensorDescTypeWithEncoding(nullptr, descTy.getBlockType(),
                                                 encoding);
     }
   }
   for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(resultTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       resultTys[i] = getTensorDescTypeWithEncoding(
           nullptr, descTy.getBlockType(), encoding);
     }