@@ -28,118 +28,18 @@ inline SmallVector<Value> translateTMAIndices(BuilderT &builder, Location loc,
   return indices;
 }

-inline gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
-                                                  ArrayRef<int64_t> shape) {
-  auto rank = shape.size();
-  if (ctaLayout.getRank() == rank)
-    return ctaLayout;
+gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
+                                           ArrayRef<int64_t> shape);

-  auto ctx = ctaLayout.getContext();
-  if (ctaLayout.getRank() > rank) {
-    unsigned rankDiff = ctaLayout.getRank() - rank;
-    return gpu::CTALayoutAttr::get(
-        ctx, ctaLayout.getCTAsPerCGA().drop_front(rankDiff),
-        ctaLayout.getCTASplitNum().drop_front(rankDiff),
-        ctaLayout.getCTAOrder().drop_front(rankDiff));
-  }
-  // For rank-reducing loads, we need to rank-increase the CTA Layout
-  auto rankDiff = rank - ctaLayout.getRank();
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    assert(shape[i] == 1 && "Should only happen for rank-reducing loads");
-  }
-  SmallVector<unsigned> CTAsPerCGA(rank, 1);
-  SmallVector<unsigned> CTASplitNum(rank, 1);
-  SmallVector<unsigned> CTAOrder(rank, 1);
-
-  llvm::copy(ctaLayout.getCTAsPerCGA(), CTAsPerCGA.begin() + rankDiff);
-  llvm::copy(ctaLayout.getCTASplitNum(), CTASplitNum.begin() + rankDiff);
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    CTAOrder[i] = rank - i;
-  }
-  llvm::copy(ctaLayout.getCTAOrder(), CTAOrder.begin() + rankDiff);
-  return gpu::CTALayoutAttr::get(ctx, CTAsPerCGA, CTASplitNum, CTAOrder);
-}
-
-inline gpu::SharedEncodingTrait
+gpu::SharedEncodingTrait
 updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
-                       RankedTensorType tensorType) {
-  auto ctx = encoding.getContext();
-  auto ctaLayout = gpu::getCTALayout(encoding);
-  if (auto nvmmaEnc = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding)) {
-    auto existingCta = nvmmaEnc.getCTALayout();
-    if (!existingCta)
-      return nvmmaEnc;
-
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::NVMMASharedEncodingAttr::get(
-        ctx, nvmmaEnc.getSwizzlingByteWidth(), nvmmaEnc.getTransposed(),
-        nvmmaEnc.getElementBitWidth(), nvmmaEnc.getFp4Padded(), newCtaEnc);
-  }
-  if (auto swizEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding)) {
-    auto existingCta = swizEnc.getCTALayout();
-    if (!existingCta)
-      return swizEnc;
-
-    auto rank = tensorType.getRank();
-    auto oldOrder = swizEnc.getOrder();
-    SmallVector<unsigned> order;
-    for (int i = 0; i + oldOrder.size() < rank; ++i)
-      order.push_back(rank - i - 1);
-    for (int i = 0; i < oldOrder.size(); ++i) {
-      // If it is a rank-reducing load, we need to drop the last dimensions.
-      if (oldOrder[i] >= rank)
-        continue;
-      order.push_back(oldOrder[i]);
-    }
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::SwizzledSharedEncodingAttr::get(
-        ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
-        order, newCtaEnc);
-  }
-
-  constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding";
-  if (op)
-    op->emitError() << msg;
-  llvm::report_fatal_error(msg);
-}
+                       RankedTensorType tensorType);

-inline triton::gpu::SharedEncodingTrait
+triton::gpu::SharedEncodingTrait
 getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
-                          Value desc) {
-  auto descBlockType = cast<TensorDescType>(desc.getType()).getBlockType();
-  Attribute encoding = descBlockType.getEncoding();
-  if (!encoding) {
-    constexpr auto msg =
-        "Internal Error: Tensor descriptor should have encoding set";
-    if (op)
-      op->emitError() << msg;
-    llvm::report_fatal_error(msg);
-  }
-  auto sharedEnc = cast<gpu::SharedEncodingTrait>(encoding);
-  if (descBlockType.getShape() == tensorType.getShape())
-    return sharedEnc;
-
-  return updateEncodingForShape(op, sharedEnc, tensorType);
-}
+                          Value desc);

-inline int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape) {
-  assert(encoding);
-  auto mmaEncoding =
-      llvm::dyn_cast_or_null<gpu::NVMMASharedEncodingAttr>(encoding);
-
-  // The bounding box inner dimension must be less than or equal to the
-  // swizzle size.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  // We clamp the block size and the codegen will emit multiple copy
-  // operations.
-  if (mmaEncoding) {
-    auto elemSize = mmaEncoding.getElementBitWidth() / 8;
-    return mmaEncoding.getSwizzlingByteWidth() / elemSize;
-  }
-
-  auto shapePerCTA = gpu::getShapePerCTA(encoding, shape);
-  return shapePerCTA.back();
-}
+int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape);

 inline int64_t getTMAContigDim(RankedTensorType tensorType) {
   return getTMAContigDim(tensorType.getEncoding(), tensorType.getShape());
@@ -149,61 +49,9 @@ inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
   return getTMAContigDim(memDescType.getEncoding(), memDescType.getShape());
 }

-inline std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
-  if (!mmaEncoding) {
-    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding);
-    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
-        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
-      if (op)
-        op->emitError("Unhandled encoding type");
-      return std::nullopt;
-    }
-  }
-
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-  assert(!fp4Padded || swizzleBytes == 128 &&
-                           "elem type .b4x16_p64 supports only 128B swizzling");
+std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);

-  int32_t swizzleMode = 0;
-  if (swizzleBytes == 128) {
-    swizzleMode = 3;
-  } else if (swizzleBytes == 64) {
-    swizzleMode = 2;
-  } else if (swizzleBytes == 32) {
-    swizzleMode = 1;
-  }
-  return swizzleMode;
-}
-
-inline std::optional<int> getTMAElementType(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-
-  if (fp4Padded)
-    return 14; // .b4x16_p64
-
-  auto elemSize = ty.getBlockType().getElementTypeBitWidth() / 8;
-  switch (elemSize) {
-  case 1:
-    return 0;
-  case 2:
-    return 1;
-  case 4:
-    return 2;
-  default:
-    break;
-  }
-  if (op) {
-    op->emitError()
-        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
-        << elemSize;
-  }
-  return std::nullopt;
-}
+std::optional<int> getTMAElementType(Operation *op, TensorDescType ty);

 template <typename BuilderT>
 mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,