Skip to content

Commit 3796f3f

Browse files
Revert "[LAYOUTS] [NFC] Make order accept a RankedTensorType (#6007)"
This reverts commit dce695e.
1 parent 2e41c90 commit 3796f3f

File tree

27 files changed

+221
-314
lines changed

27 files changed

+221
-314
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
135135
if (rank > 1) {
136136
// reorder the shape and constancy vectors by the axis order:
137137
// from the fastest-changing to the smallest-changing axis
138-
SmallVector<unsigned> order = getOrder(rtType);
138+
SmallVector<unsigned> order = getOrder(encoding);
139139
if (rank != order.size())
140140
return resultVals;
141141
elemsPerThread = applyPermutation(elemsPerThread, order);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class SharedMemoryObject {
363363
auto allocShape = memDesc.getAllocShape();
364364
auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
365365
memDesc.getEncoding(), allocShape);
366-
auto layoutOrder = triton::gpu::getOrder(memDesc);
366+
auto layoutOrder = triton::gpu::getOrder(memDesc.getEncoding());
367367
auto allocStrides = SharedMemoryObject::getStridesForShape(
368368
allocShapePerCTA, layoutOrder, loc, rewriter);
369369
return SmallVector<Value>(allocStrides.end() - offsets.size(),

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,7 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
134134
// the order of the elements within a thread.
135135
// For shared Layout, the order refers to which dimension of the original tensor
136136
// is contiguous in shared memory.
137-
SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
138-
ArrayRef<int64_t> shape);
139-
SmallVector<unsigned> getOrder(RankedTensorType type);
140-
141-
SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
142-
ArrayRef<int64_t> shape);
143-
SmallVector<unsigned> getOrder(MemDescType type);
144-
SmallVector<unsigned> getOrder(TensorOrMemDesc type);
145-
146-
// Order of the elements in the shared memory as defined at layout creation
147-
// If this layout is associated with a MemDesc with a different shape
148-
// it may return a different order than the actual order of the elements
149-
SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
137+
SmallVector<unsigned> getOrder(Attribute layout);
150138

151139
// Returns the dimensions along which warpId's are distributed.
152140
// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
@@ -155,16 +143,17 @@ SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
155143
// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows
156144
// [warp0 warp2 warp4 warp6]
157145
// [warp1 warp3 warp5 warp7]
158-
SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
159-
ArrayRef<int64_t> shape);
160-
SmallVector<unsigned> getWarpOrder(RankedTensorType type);
146+
// Note that in most cases, getWarpOrder and getOrder return the same results.
147+
// But this is not guaranteed.
148+
SmallVector<unsigned> getWarpOrder(Attribute layout);
161149

162150
// Returns the dimensions along which threadId's are distributed.
163151
// Similar to warpOrder, threadOrder is necessary to tell the specific thread
164152
// distribution in the warp.
165-
SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
166-
ArrayRef<int64_t> shape);
167-
SmallVector<unsigned> getThreadOrder(RankedTensorType type);
153+
// Note that, in most cases, getThreadOrder and getOrder return the same
154+
// results. But this is not guaranteed. One exception is mfma.transposed layout,
155+
// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
156+
SmallVector<unsigned> getThreadOrder(Attribute layout);
168157

169158
CTALayoutAttr getCTALayout(Attribute layout);
170159

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,6 @@ def NVMMASharedEncodingAttr :
464464
SmallVector<unsigned> getCTAsPerCGA() const;
465465
SmallVector<unsigned> getCTAOrder() const;
466466
SmallVector<unsigned> getCTASplitNum() const;
467-
SmallVector<unsigned> getOrder() const {
468-
return getTransposed() ? SmallVector<unsigned>({0, 1}) : SmallVector<unsigned>({1, 0});
469-
}
470467
}];
471468
let hasCustomAssemblyFormat = 1;
472469
}
@@ -520,33 +517,25 @@ We call each individual tile "rep".
520517
"SmallVector<unsigned>",
521518
"getWarpsPerCTA">,
522519

520+
InterfaceMethod<"Get the order of the warps per CTA. The fastest-changing axis first",
521+
"SmallVector<unsigned>",
522+
"getWarpOrder">,
523523

524524
InterfaceMethod<"Get the shape of the threads per warp",
525525
"SmallVector<unsigned>",
526526
"getThreadsPerWarp">,
527527

528+
InterfaceMethod<"Get the order of the threads per warp. The fastest-changing axis first",
529+
"SmallVector<unsigned>",
530+
"getThreadOrder">,
531+
528532
InterfaceMethod<"Get the shape of the values per thread.",
529533
"SmallVector<unsigned>",
530534
"getSizePerThread">,
531535
InterfaceMethod<"Convert to LinearLayout.",
532536
"LinearLayout",
533537
"toLinearLayout",
534-
(ins "ArrayRef<int64_t>":$shape)>,
535-
536-
// Legacy methods: They do not take into account the shape of the tensor
537-
// that is, the fact that we use them to tile the tensor.
538-
InterfaceMethod<"Get the default order of the registers per warp. The fastest-changing axis first",
539-
"SmallVector<unsigned>",
540-
"getDefaultOrder">,
541-
542-
InterfaceMethod<"Get the default order of the threads per warp. The fastest-changing axis first",
543-
"SmallVector<unsigned>",
544-
"getDefaultThreadOrder">,
545-
546-
InterfaceMethod<"Get the default order of the warps per CTA. The fastest-changing axis first",
547-
"SmallVector<unsigned>",
548-
"getDefaultWarpOrder">
549-
538+
(ins "ArrayRef<int64_t>":$shape)>
550539
];
551540
}
552541

@@ -594,16 +583,13 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
594583
SmallVector<unsigned> getCTAOrder() const;
595584
SmallVector<unsigned> getCTASplitNum() const;
596585
SmallVector<unsigned> getWarpsPerCTA() const;
586+
SmallVector<unsigned> getWarpOrder() const;
597587
SmallVector<unsigned> getThreadsPerWarp() const;
588+
SmallVector<unsigned> getThreadOrder() const;
598589

599590
SmallVector<unsigned> getSizePerThread() const;
600591

601592
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
602-
603-
// Legacy methods: They do not take into account the shape of the tensor
604-
SmallVector<unsigned> getDefaultWarpOrder() const;
605-
SmallVector<unsigned> getDefaultThreadOrder() const;
606-
SmallVector<unsigned> getDefaultOrder() const;
607593
}];
608594
}
609595

@@ -634,8 +620,6 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
634620
SmallVector<unsigned> getContigPerThread() const;
635621
SmallVector<unsigned> getContigPerWarp() const;
636622
SmallVector<unsigned> getOrder() const;
637-
SmallVector<unsigned> getWarpOrder() const;
638-
SmallVector<unsigned> getThreadOrder() const;
639623

640624
// Generalizes get{Warp,Thread,CTA}Order to linear layouts.
641625
// Returns the order of the dimensions `dimName` of the layout.

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
3636

3737
int numWarps = lookupNumWarps(cvtOp);
3838
auto enc = BlockedEncodingAttr::get(
39-
ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcType),
39+
ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcMma),
4040
numWarps, threadsPerWarp, numCTAs);
4141
auto tmpType = RankedTensorType::get(dstType.getShape(),
4242
dstType.getElementType(), enc);

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@ static int __builtin_ctz(unsigned x) {
3232
namespace {
3333

3434
LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
35-
LinearLayout regLayout,
36-
triton::gpu::SharedEncodingTrait dstEnc,
35+
LinearLayout regLayout, Attribute dstEnc,
3736
int elemBitWidth) {
3837
StringAttr kBlock = StringAttr::get(ctx, ("block"));
3938
int rank = shape.size();
4039

4140
LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc);
42-
auto sharedOrder = triton::gpu::getOrder(dstEnc, shape);
41+
auto sharedOrder = triton::gpu::getOrder(dstEnc);
4342

4443
// sharedLayout's in-dims are currently (offset, block). Reshape to
4544
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ struct MemDescSubviewOpConversion
369369
auto b = TritonLLVMOpBuilder(loc, rewriter);
370370
auto srcTy = op.getSrc().getType();
371371
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
372-
auto layoutOrder = getOrder(srcTy);
372+
auto layoutOrder = getOrder(srcTy.getEncoding());
373373

374374
// newBase = base + offset
375375
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),

0 commit comments

Comments (0)