@@ -6,7 +6,7 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
66include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
77
88//===----------------------------------------------------------------------===//
9- // Traits and Interfaces
9+ // Traits, Interfaces and shared Parameters
1010//===----------------------------------------------------------------------===//
1111
1212def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
@@ -55,6 +55,11 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
5555def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
5656 SharedEncodingTrait, ["getAlignment"]>;
5757
58+ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
59+ "linear layout"> {
60+ let cppAccessorType = "const LinearLayout &";
61+ }
62+
5863//===----------------------------------------------------------------------===//
5964// Base Attribute
6065//===----------------------------------------------------------------------===//
@@ -369,14 +374,15 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
369374
370375def PaddedSharedEncodingAttr
371376 : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
372- [SharedEncodingTrait, LayoutEncodingTrait ]> {
377+ [SharedEncodingTrait, DeclareLayoutEncodingMethods ]> {
373378 let mnemonic = "padded_shared";
374379
375380 let description = [{
376381An encoding for tensors whose elements may be simultaneously accessed by
377382different GPU threads in the programs, via shared memory. In other words,
378383for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
379- Compared to SwizzledSharedEncodingAttr, this encoding uses padding to avoid
384+ Compared to SwizzledSharedEncodingAttr, this encoding combines padding with
385+ element reordering via linear transformation (e.g. row permutation) to avoid
380386shared memory bank conflicts.
381387
382388Formally, given a layout:
@@ -388,48 +394,93 @@ at index i, the corresponding shared memory location index is
388394 i + \sum_{k} (i / interval_k) * pad_k = 1
389395`<interval_i>` and `<pad_i>` all need to be power of two.
390396
391- Some concrete examples, using `eM` to mean tensor elements and `pN` to mean
392- padding:
397+ Some concrete examples ignoring the linear component, using `eM` to mean tensor
398+ elements and `pN` to mean padding:
393399
3944001. Single interval-padding pair:
395401
396- #ttg.padded_shared<[2:+2]>
402+ #ttg.padded_shared<[2:+2], {...} >
397403 [e0, e1, p0, p1,
398404 e2, e3, p2, p3,
399405 ...]
400406
4014072. Double interval-padding pairs:
402408
403- #ttg.padded_shared<[2:+1, 4:+2]>
409+ #ttg.padded_shared<[2:+1, 4:+2], {...} >
404410 [e0, e1, p0,
405411 e2, e3, p1, p2, p3,
406412 e4, e5, p4,
407413 e6, e7, p5, p6, p7,
408414 ...]
409415
410- In addition to interval-padding pairs, this encoding requires an `order` to
411- specify the logical tensor dimenions from the fastest-to slowest-varying.
412- It may optionally support CGA level organization like other encoding
413- attributes too, for example,
414- #ttg.padded_shared<[2:+1, 4:+2] {
415- order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1],
416- CTAOrder = [0, 1]}>
416+ Furthermore this encoding allows for a linear remapping from the 1-D shared
417+ memory offset to logical n-D tensor elements. The remapping is given in the form
418+ of linear bases mapping from offset to [dim0, dim1...dimN-1].
419+ See LinearLayout.h for more details on how linear layouts are applied to remap
420+ elements.
421+ Some concrete examples using `xN` and `yN` to mean the logical n-D tensor elements
422+ and `pN` to mean padding:
423+
424+ 1. 1D Single interval-padding with strided elements
425+
426+ #ttg.padded_shared<[2:+2] {offset = [[2], [1]], block = []}>
427+ [x0, x2, p0, p1,
428+ x1, x3, p2, p3,
429+ ...]
430+
431+ 2. 2D single interval-padding with rearranged rows.
432+
433+ #ttg.padded_shared<[16:+1] {offset = [[0, 1], [0, 2], /*gap, stride by 2 rows*/[2, 0], [4, 0], [1, 0]], block = []}>
434+ [
435+ x0y0, x0y1, x0y2, x0y3,
436+ x2y0, x2y1, x2y2, x2y3,
437+ x4y0, x4y1, x4y2, x4y3,
438+ x6y0, x6y1, x6y2, x6y3,
439+ p0,
440+ x1y0, x1y1, x1y2, x1y3,
441+ x3y0, x3y1, x3y2, x3y3,
442+ x5y0, x5y1, x5y2, x5y3,
443+ x7y0, x7y1, x7y2, x7y3,
444+ p1,
445+ ]
446+
447+ For identity mappings a short form based on order and shape is used to increase readability. The following two encodings are the same:
448+
449+ #ttg.padded_shared<[2:+2] {order = [1, 0], shape = [16, 32]}>
450+ #ttg.padded_shared<[2:+2] {offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, 0], [8, 0]], block = []}>
451+
452+
417453 }];
418454
419455 let parameters = (ins
420456 ArrayRefParameter<"unsigned">:$intervals,
421457 ArrayRefParameter<"unsigned">:$paddings,
422- // Order of logical tensor dimensions; fastest-varying first.
423- ArrayRefParameter<"unsigned">:$order,
424- "CTALayoutAttr":$CTALayout
458+ LinearLayoutParam:$linearComponent
425459 );
426460
427461 let builders = [
428462 AttrBuilder<(ins "ArrayRef<std::pair<unsigned, unsigned>>":$intervalPads,
429- "ArrayRef<unsigned>":$order, "CTALayoutAttr":$ctaLayout)>,
463+ "LinearLayout":$linearComponent)>,
464+
465+ // Builder to create an identity mapping as the linear component
466+ AttrBuilder<(ins "ArrayRef<std::pair<unsigned, unsigned>>":$intervalPads,
467+ "ArrayRef<unsigned>":$order, "ArrayRef<int64_t>":$shape,
468+ "CTALayoutAttr":$ctaLayout)>,
430469 ];
431470
432471 let extraClassDeclaration = extraBaseClassDeclaration # [{
472+ // Returns the order of the dimensions `dimName` of the layout.
473+ // If more than one dimension is of size one, it uses defaultOrder to determine
474+ // the order of the dimensions of size one.
475+ SmallVector<unsigned> orderPerDim(StringAttr dimName,
476+ ArrayRef<unsigned> defaultOrder) const;
477+ SmallVector<unsigned> getOrder() const;
478+
479+ // Returns the bases of the dimensions `dimName` of the linear_component.
480+ // If skipBroadcast is false, we count a base zero
481+ SmallVector<unsigned> basesPerDim(StringAttr dimName,
482+ bool skipBroadcast = true) const;
483+
433484 unsigned getMinInterval() const {
434485 return *llvm::min_element(getIntervals());
435486 }
@@ -708,11 +759,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
708759// Linear Layout Encoding
709760//===----------------------------------------------------------------------===//
710761
711- def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
712- "linear layout"> {
713- let cppAccessorType = "const LinearLayout &";
714- }
715-
716762def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> {
717763 let mnemonic = "linear";
718764
0 commit comments