
Commit 6744eea

Reland upstream commits #6007 and #6004 (#3575)
Reland "[LAYOUTS] [NFC] Make order accept a RankedTensorType (#6007)" and "[LAYOUTS] [NFC] Just accept DistributedEncodings in SliceLayout (#6004)"
1 parent 9da95a7 commit 6744eea
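
The practical effect of the relanded pair is that order queries flow through typed values instead of bare encoding attributes, so the shape being tiled is always in scope. A minimal before/after sketch, assuming a RankedTensorType `ty` carrying a distributed encoding (the function and variable names are illustrative, not from this diff):

    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    using namespace mlir;

    SmallVector<unsigned> queryOrder(RankedTensorType ty) {
      // Before this commit, callers stripped the type down to its attribute:
      //   auto order = triton::gpu::getOrder(ty.getEncoding());
      // Now the type itself is the argument, so the overload can consult both
      // the encoding and ty.getShape():
      return triton::gpu::getOrder(ty);
    }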


43 files changed: +2182 −1043 lines

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 1 addition & 1 deletion

@@ -137,7 +137,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     if (rank > 1) {
       // reorder the shape and constancy vectors by the axis order:
       // from the fastest-changing to the smallest-changing axis
-      SmallVector<unsigned> order = getOrder(encoding);
+      SmallVector<unsigned> order = getOrder(rtType);
       if (rank != order.size())
         return resultVals;
       elemsPerThread = applyPermutation(elemsPerThread, order);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 1 deletion

@@ -387,7 +387,7 @@ class SharedMemoryObject {
     auto allocShape = memDesc.getAllocShape();
     auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
         memDesc.getEncoding(), allocShape);
-    auto layoutOrder = triton::gpu::getOrder(memDesc.getEncoding());
+    auto layoutOrder = triton::gpu::getOrder(memDesc);
     auto allocStrides = SharedMemoryObject::getStridesForShape(
         allocShapePerCTA, layoutOrder, loc, rewriter);
     return SmallVector<Value>(allocStrides.end() - offsets.size(),
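
The `MemDescType` overload used here matters because a shared layout can end up attached to a memdesc whose shape differs from the one the layout was created for (a subview, for instance). A hedged sketch of the distinction, with `memDesc` as an illustrative value:

    // Creation-time order, read off the encoding alone:
    auto sharedEnc =
        cast<triton::gpu::SharedEncodingTrait>(memDesc.getEncoding());
    SmallVector<unsigned> declaredOrder = triton::gpu::getDefaultOrder(sharedEnc);
    // Order for the shape this memdesc actually has; per the new doc comment
    // in Dialect.h, the two may disagree when the shapes differ:
    SmallVector<unsigned> actualOrder = triton::gpu::getOrder(memDesc);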

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 19 additions & 8 deletions

@@ -134,7 +134,19 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
 // the order of the elements within a thread.
 // For shared Layout, the order refers to which dimension of the original tensor
 // is contiguous in shared memory.
-SmallVector<unsigned> getOrder(Attribute layout);
+SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
+                               ArrayRef<int64_t> shape);
+SmallVector<unsigned> getOrder(RankedTensorType type);
+
+SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
+                               ArrayRef<int64_t> shape);
+SmallVector<unsigned> getOrder(MemDescType type);
+SmallVector<unsigned> getOrder(TensorOrMemDesc type);
+
+// Order of the elements in the shared memory as defined at layout creation
+// If this layout is associated with a MemDesc with a different shape
+// it may return a different order than the actual order of the elements
+SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
 
 // Returns the dimensions along which warpId's are distributed.
 // warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
@@ -143,17 +155,16 @@ SmallVector<unsigned> getOrder(Attribute layout);
 // E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows
 // [warp0 warp2 warp4 warp6]
 // [warp1 warp3 warp5 warp7]
-// Note that in most cases, getWarpOrder and getOrder return the same results.
-// But this is not guaranteed.
-SmallVector<unsigned> getWarpOrder(Attribute layout);
+SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
+                                   ArrayRef<int64_t> shape);
+SmallVector<unsigned> getWarpOrder(RankedTensorType type);
 
 // Returns the dimensions along which threadId's are distributed.
 // Similar to warpOrder, threadOrder is necessary to tell the specific thread
 // distribution in the warp.
-// Note that, in most cases, getThreadOrder and getOrder return the same
-// results. But this is not guaranteed. One exception is mfma.transposed layout,
-// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
-SmallVector<unsigned> getThreadOrder(Attribute layout);
+SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
+                                     ArrayRef<int64_t> shape);
+SmallVector<unsigned> getThreadOrder(RankedTensorType type);
 
 CTALayoutAttr getCTALayout(Attribute layout);
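
A sketch of how the new overload set pairs up. `distEnc`, `sharedEnc`, `tensorTy`, and `memDescTy` are illustrative placeholders; in each pair the type-taking form presumably forwards to the (encoding, shape) form:

    using namespace mlir::triton::gpu;
    // Distributed layouts: the order may depend on the shape being tiled.
    SmallVector<unsigned> a = getOrder(distEnc, tensorTy.getShape());
    SmallVector<unsigned> b = getOrder(tensorTy);  // same query via the type
    // Shared layouts: the analogous pair, plus the memdesc convenience form.
    SmallVector<unsigned> c = getOrder(sharedEnc, memDescTy.getShape());
    SmallVector<unsigned> d = getOrder(memDescTy);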

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 27 additions & 12 deletions

@@ -464,6 +464,9 @@ def NVMMASharedEncodingAttr :
     SmallVector<unsigned> getCTAsPerCGA() const;
     SmallVector<unsigned> getCTAOrder() const;
     SmallVector<unsigned> getCTASplitNum() const;
+    SmallVector<unsigned> getOrder() const {
+      return getTransposed() ? SmallVector<unsigned>({0, 1}) : SmallVector<unsigned>({1, 0});
+    }
   }];
   let hasCustomAssemblyFormat = 1;
 }
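
The inline body added here fixes the order at layout-definition time rather than computing it. A short hedged reading, with `enc` as an illustrative NVMMASharedEncodingAttr value:

    SmallVector<unsigned> order = enc.getOrder();
    // Not transposed: dim 1 is the fastest-changing axis -> order == {1, 0}.
    // Transposed:     dim 0 is the fastest-changing axis -> order == {0, 1}.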
@@ -517,25 +520,33 @@ We call each individual tile "rep".
     "SmallVector<unsigned>",
     "getWarpsPerCTA">,
 
-    InterfaceMethod<"Get the order of the warps per CTA. The fastest-changing axis first",
-    "SmallVector<unsigned>",
-    "getWarpOrder">,
 
     InterfaceMethod<"Get the shape of the threads per warp",
     "SmallVector<unsigned>",
     "getThreadsPerWarp">,
 
-    InterfaceMethod<"Get the order of the threads per warp. The fastest-changing axis first",
-    "SmallVector<unsigned>",
-    "getThreadOrder">,
-
     InterfaceMethod<"Get the shape of the values per thread.",
     "SmallVector<unsigned>",
     "getSizePerThread">,
     InterfaceMethod<"Convert to LinearLayout.",
     "LinearLayout",
     "toLinearLayout",
-    (ins "ArrayRef<int64_t>":$shape)>
+    (ins "ArrayRef<int64_t>":$shape)>,
+
+    // Legacy methods: They do not take into account the shape of the tensor
+    // that is, the fact that we use them to tile the tensor.
+    InterfaceMethod<"Get the default order of the registers per warp. The fastest-changing axis first",
+    "SmallVector<unsigned>",
+    "getDefaultOrder">,
+
+    InterfaceMethod<"Get the default order of the threads per warp. The fastest-changing axis first",
+    "SmallVector<unsigned>",
+    "getDefaultThreadOrder">,
+
+    InterfaceMethod<"Get the default order of the warps per CTA. The fastest-changing axis first",
+    "SmallVector<unsigned>",
+    "getDefaultWarpOrder">
+
   ];
 }
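
The renamed interface methods make shape-blindness explicit: `getDefault{Order,ThreadOrder,WarpOrder}` answer from the layout alone, while the shape-aware free functions declared in Dialect.h above become the preferred entry points. A hedged sketch of the split (placeholder names):

    // Legacy, shape-blind answer straight from the encoding:
    SmallVector<unsigned> w0 = distEnc.getDefaultWarpOrder();
    // Shape-aware answer, the form lowering code is being migrated to:
    SmallVector<unsigned> w1 =
        triton::gpu::getWarpOrder(distEnc, tensorTy.getShape());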

@@ -583,13 +594,16 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getCTAOrder() const;
     SmallVector<unsigned> getCTASplitNum() const;
     SmallVector<unsigned> getWarpsPerCTA() const;
-    SmallVector<unsigned> getWarpOrder() const;
     SmallVector<unsigned> getThreadsPerWarp() const;
-    SmallVector<unsigned> getThreadOrder() const;
 
     SmallVector<unsigned> getSizePerThread() const;
 
     LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
+
+    // Legacy methods: They do not take into account the shape of the tensor
+    SmallVector<unsigned> getDefaultWarpOrder() const;
+    SmallVector<unsigned> getDefaultThreadOrder() const;
+    SmallVector<unsigned> getDefaultOrder() const;
   }];
 }

@@ -620,6 +634,8 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
     SmallVector<unsigned> getContigPerThread() const;
     SmallVector<unsigned> getContigPerWarp() const;
     SmallVector<unsigned> getOrder() const;
+    SmallVector<unsigned> getWarpOrder() const;
+    SmallVector<unsigned> getThreadOrder() const;
 
     // Generalizes get{Warp,Thread,CTA}Order to linear layouts.
     // Returns the order of the dimensions `dimName` of the layout.
@@ -1228,8 +1244,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   let parameters = (
     ins
     "unsigned":$dim,
-    // TODO: constraint here to only take distributed encodings
-    "Attribute":$parent
+    "DistributedEncodingTrait":$parent
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
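
Tightening `$parent` from `Attribute` to `DistributedEncodingTrait` promotes the old TODO into a guarantee enforced by the generated builder and verifier: a slice can only peel a dimension off a layout that actually distributes data across threads. A hedged construction sketch, with `blockedEnc` as an illustrative parent:

    // BlockedEncodingAttr implements DistributedEncodingTrait, so it is a
    // valid parent; a shared encoding would no longer be accepted here.
    auto slice =
        triton::gpu::SliceEncodingAttr::get(ctx, /*dim=*/0, blockedEnc);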

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
 
   int numWarps = lookupNumWarps(cvtOp);
   auto enc = BlockedEncodingAttr::get(
-      ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcMma),
+      ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcType),
       numWarps, threadsPerWarp, numCTAs);
   auto tmpType = RankedTensorType::get(dstType.getShape(),
                                        dstType.getElementType(), enc);

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 3 additions & 2 deletions

@@ -32,13 +32,14 @@ static int __builtin_ctz(unsigned x) {
 namespace {
 
 LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
-                                  LinearLayout regLayout, Attribute dstEnc,
+                                  LinearLayout regLayout,
+                                  triton::gpu::SharedEncodingTrait dstEnc,
                                   int elemBitWidth) {
   StringAttr kBlock = StringAttr::get(ctx, ("block"));
   int rank = shape.size();
 
   LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc);
-  auto sharedOrder = triton::gpu::getOrder(dstEnc);
+  auto sharedOrder = triton::gpu::getOrder(dstEnc, shape);
 
   // sharedLayout's in-dims are currently (offset, block). Reshape to
   // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
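
Threading `shape` into the shared-order query keeps `getRegToSharedLayout` consistent with the rest of this change: the order is answered for the shape actually being lowered rather than the encoding's creation-time shape. A hedged illustration of the call-site difference:

    // Old: order derived from dstEnc alone, ignoring `shape`.
    //   auto sharedOrder = triton::gpu::getOrder(dstEnc);
    // New: the same query, but shape-aware.
    auto sharedOrder = triton::gpu::getOrder(dstEnc, shape);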

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion

@@ -376,7 +376,7 @@ struct MemDescSubviewOpConversion
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcTy = op.getSrc().getType();
     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
-    auto layoutOrder = getOrder(srcTy.getEncoding());
+    auto layoutOrder = getOrder(srcTy);
 
     // newBase = base + offset
     auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
