@@ -167,21 +167,22 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
   ];
 }

-def SwizzledSharedEncodingAttr :
-    TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+def SwizzledSharedEncodingAttr
+    : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "swizzled_shared";

   let description = [{
 An encoding for tensors whose elements may be simultaneously accessed by
-different cuda threads in the programs, via shared memory. In other words,
+different GPU threads in the programs, via shared memory. In other words,
 for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.

 In order to avoid shared memory bank conflicts, elements may be swizzled.
 Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1].

 1. Basic swizzling

-  #shared <{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
+  #ttg.swizzled_shared <{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
   [ 0,  1,  2,  3], // xor with 0
   [ 5,  4,  7,  6], // xor with 1
   [10, 11,  8,  9], // xor with 2
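As the next hunk's context line notes, example 1 swizzles row r by xor'ing its
column indices with r, i.e. out[r][c^r]. A minimal Python sketch of that rule,
illustrative only and not part of the patch:

```python
# Example 1 (vec=1, perPhase=1, maxPhase=4): row r stores its input row with
# columns xor'ed by r, so out[r][c ^ r] = in[r][c].
rows = cols = max_phase = 4
out = [[0] * cols for _ in range(rows)]
for r in range(rows):
    for c in range(cols):
        out[r][c ^ (r % max_phase)] = r * cols + c  # input is [0, 1, ..., n-1]
# out == [[0, 1, 2, 3], [5, 4, 7, 6], [10, 11, 8, 9], [15, 14, 13, 12]]
```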
@@ -192,7 +193,7 @@ out[r][c^r]).

 2. Multiple rows per phase

-  #shared <{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
+  #ttg.swizzled_shared <{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
   [ 0,  1,  2,  3], // phase 0 (xor with 0)
   [ 4,  5,  6,  7],
   [ 9,  8, 11, 10], // phase 1 (xor with 1)
@@ -203,7 +204,7 @@ means that pairs of 2 rows get the same swizzling.

 3. Max-phase applied

-  #shared <{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
+  #ttg.swizzled_shared <{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
   [ 0,  1,  2,  3], // phase 0 (xor with 0)
   [ 5,  4,  7,  6], // phase 1 (xor with 1)
   [ 8,  9, 10, 11], // phase 0
@@ -218,7 +219,7 @@ effect of limiting the maximum value of the xor to m-1.

 4. Max-phase and per-phase

-  #shared <{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
+  #ttg.swizzled_shared <{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
   [ 0,  1,  2,  3], // phase 0 (xor with 0)
   [ 4,  5,  6,  7], // phase 0
   [ 9,  8, 11, 10], // phase 1 (xor with 1)
@@ -234,7 +235,7 @@ maximum value of maxPhase-1. In other words, elements of row r are xor'ed with

 5. Adding vec

-  #shared <{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
+  #ttg.swizzled_shared <{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
   [ 0,  1,  2,  3,  4,  5,  6,  7],
   [10, 11,  8,  9, 14, 15, 12, 13],
   [20, 21, 22, 23, 16, 17, 18, 19],
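All five examples follow one closed form: row r is assigned phase
(r / perPhase) % maxPhase, and columns are swizzled in groups of vec by
xor'ing the group index with that phase. A hypothetical helper, not part of
the patch, that reproduces each table above:

```python
# Reproduces the example tables from (vec, perPhase, maxPhase); the input
# tensor is [0, 1, ..., n-1] as in the examples.
def swizzle(rows, cols, vec, per_phase, max_phase):
    out = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        phase = (r // per_phase) % max_phase
        for c in range(cols):
            group = c // vec  # columns move in groups of `vec`
            new_c = (group ^ phase) * vec + c % vec
            out[r][new_c] = r * cols + c
    return out

assert swizzle(4, 8, 2, 1, 4)[1] == [10, 11, 8, 9, 14, 15, 12, 13]  # example 5
assert swizzle(3, 4, 1, 2, 4)[2] == [9, 8, 11, 10]                  # example 2
```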
@@ -383,6 +384,88 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }

+def PaddedSharedEncodingAttr
+    : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "padded_shared";
+
+  let description = [{
+An encoding for tensors whose elements may be simultaneously accessed by
+different GPU threads in the programs, via shared memory. In other words,
+for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
+Compared to SwizzledSharedEncodingAttr, this encoding uses padding to avoid
+shared memory bank conflicts.
+
+Formally, given a layout:
+  padded_shared<[<interval_0>:+<pad_0>, <interval_1>:+<pad_1>, ...]>
+we insert a padding of `<pad_i>` elements after every `<interval_i>` elements.
+Multiple interval-padding pairs are supported to allow multi-tiered padding
+schemes; they compose additively. So for a 1-D tensor element at index i, the
+corresponding shared memory location index is
+  i + \sum_k floor(i / interval_k) * pad_k
+Every `<interval_i>` and `<pad_i>` must be a power of two.
+
+Some concrete examples, using `eM` to mean tensor elements and `pN` to mean
+padding:
+
+1. Single interval-padding pair:
+
+  #ttg.padded_shared<[2:+2]>
+  [e0, e1, p0, p1,
+   e2, e3, p2, p3,
+   ...]
+
+2. Two interval-padding pairs:
+
+  #ttg.padded_shared<[2:+1, 4:+2]>
+  [e0, e1, p0,
+   e2, e3, p1, p2, p3,
+   e4, e5, p4,
+   e6, e7, p5, p6, p7,
+   ...]
+
+In addition to interval-padding pairs, this encoding requires an `order` to
+specify the logical tensor dimensions from fastest- to slowest-varying. Like
+other encoding attributes, it may optionally carry CGA-level organization,
+for example:
+  #ttg.padded_shared<[2:+1, 4:+2] {
+    order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1],
+    CTAOrder = [0, 1]}>
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"unsigned">:$intervals,
+    ArrayRefParameter<"unsigned">:$paddings,
+    // Order of logical tensor dimensions; fastest-varying first.
+    ArrayRefParameter<"unsigned">:$order,
+    "CTALayoutAttr":$CTALayout
+  );
+
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<std::pair<unsigned, unsigned>>":$intervalPads,
+                     "ArrayRef<unsigned>":$order, "CTALayoutAttr":$ctaLayout)>,
+  ];
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getOrder().size(); }
+    int32_t getAlignment() const { return 16; }
+
+    unsigned getMinInterval() const {
+      return *llvm::min_element(getIntervals());
+    }
+
+    // Returns the total number of elements including padding given the input
+    // tensor shape.
+    int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
+
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+}
+
 def NVMMASharedEncodingAttr :
     TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
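The additive padding formula in the PaddedSharedEncodingAttr description is
easy to check numerically. A minimal sketch with hypothetical helpers, not
part of the patch; `padded_size` is just one plausible 1-D reading of
`getPaddedSize`:

```python
# Maps a linear tensor index to its shared memory offset under the additive
# interval/padding scheme: i + sum_k floor(i / interval_k) * pad_k.
def padded_offset(i, interval_pads):
    return i + sum((i // interval) * pad for interval, pad in interval_pads)

# Example 2 above, #ttg.padded_shared<[2:+1, 4:+2]>:
# e6 lands at 6 + (6 // 2) * 1 + (6 // 4) * 2 = 6 + 3 + 2 = 11.
assert padded_offset(6, [(2, 1), (4, 2)]) == 11
assert padded_offset(4, [(2, 1), (4, 2)]) == 8  # e4

# One plausible 1-D reading of getPaddedSize: the offset just past the last
# element, with no trailing padding counted (an assumption, not the patch's
# definition).
def padded_size(numel, interval_pads):
    return padded_offset(numel - 1, interval_pads) + 1
```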