Skip to content

Commit dc6e17a

Browse files
Reland "[LAYOUTS] Move order to LinearEncoding implementation (#6243)"" (#3775)
Closes #3722, #3758.
2 parents bf46a53 + d6b4ebc commit dc6e17a

File tree

21 files changed

+146
-439
lines changed

21 files changed

+146
-439
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,44 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
137137
// is contiguous in shared memory.
138138
SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
139139
ArrayRef<int64_t> shape);
140-
SmallVector<unsigned> getOrder(RankedTensorType type);
140+
// Convenience overload for ranked tensors: forwards the type's encoding and
// shape to getOrder(DistributedEncodingTrait, ArrayRef<int64_t>).
// The encoding is converted with an unchecked `cast`, so the tensor is
// expected to carry an encoding that implements DistributedEncodingTrait.
inline SmallVector<unsigned> getOrder(RankedTensorType type) {
141+
return getOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
142+
type.getShape());
143+
}
141144

142145
SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
143146
ArrayRef<int64_t> shape);
144-
SmallVector<unsigned> getOrder(MemDescType type);
145-
SmallVector<unsigned> getOrder(TensorOrMemDesc type);
147+
// Convenience overload for memory descriptors: forwards the type's encoding
// and shape to getOrder(SharedEncodingTrait, ArrayRef<int64_t>).
// The encoding is converted with an unchecked `cast`, so the descriptor is
// expected to carry an encoding that implements SharedEncodingTrait.
inline SmallVector<unsigned> getOrder(MemDescType type) {
148+
return getOrder(cast<SharedEncodingTrait>(type.getEncoding()),
149+
type.getShape());
150+
}
151+
// Dispatching overload: selects the MemDescType or RankedTensorType overload
// of getOrder based on the dynamic type of `type`.
// Anything that is not a MemDescType is converted with an unchecked `cast`
// to RankedTensorType.
inline SmallVector<unsigned> getOrder(TensorOrMemDesc type) {
152+
if (auto memDesc = dyn_cast<MemDescType>(type)) {
153+
return getOrder(memDesc);
154+
} else {
155+
auto tensorTy = cast<RankedTensorType>(type);
156+
return getOrder(tensorTy);
157+
}
158+
}
146159

147-
// Order of the elements in the shared memory as defined at layout creation
148-
// If this layout is associated with a MemDesc with a different shape
149-
// it may return a different order than the actual order of the elements
150-
SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
160+
// To be removed once we implement arbitrary swizzled layouts.
161+
// Heuristically chooses an order for the memory layout in which to store
162+
// a distributed layout, taking into account the order of the elements
163+
// and the threads.
164+
SmallVector<unsigned> getOrderForMemory(DistributedEncodingTrait layout,
165+
ArrayRef<int64_t> shape);
166+
// Convenience overload for ranked tensors: forwards the type's encoding and
// shape to getOrderForMemory(DistributedEncodingTrait, ArrayRef<int64_t>).
// The encoding is converted with an unchecked `cast`, so the tensor is
// expected to carry an encoding that implements DistributedEncodingTrait.
inline SmallVector<unsigned> getOrderForMemory(RankedTensorType type) {
167+
return getOrderForMemory(cast<DistributedEncodingTrait>(type.getEncoding()),
168+
type.getShape());
169+
}
170+
// Dispatching overload of getOrderForMemory.
// For a MemDescType this returns getOrder(memDesc), i.e. the descriptor's own
// order is used and no getOrderForMemory overload is involved. Any other type
// is converted with an unchecked `cast` to RankedTensorType and forwarded to
// getOrderForMemory(RankedTensorType).
inline SmallVector<unsigned> getOrderForMemory(TensorOrMemDesc type) {
171+
if (auto memDesc = dyn_cast<MemDescType>(type)) {
172+
return getOrder(memDesc);
173+
} else {
174+
auto tensorTy = cast<RankedTensorType>(type);
175+
return getOrderForMemory(tensorTy);
176+
}
177+
}
151178

152179
// Returns the dimensions along which warpId's are distributed.
153180
// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
@@ -158,14 +185,20 @@ SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
158185
// [warp1 warp3 warp5 warp7]
159186
SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
160187
ArrayRef<int64_t> shape);
161-
SmallVector<unsigned> getWarpOrder(RankedTensorType type);
188+
// Convenience overload for ranked tensors: forwards the type's encoding and
// shape to getWarpOrder(DistributedEncodingTrait, ArrayRef<int64_t>).
// The encoding is converted with an unchecked `cast`, so the tensor is
// expected to carry an encoding that implements DistributedEncodingTrait.
inline SmallVector<unsigned> getWarpOrder(RankedTensorType type) {
189+
return getWarpOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
190+
type.getShape());
191+
}
162192

163193
// Returns the dimensions along which threadId's are distributed.
164194
// Similar to warpOrder, threadOrder is necessary to tell the specific thread
165195
// distribution in the warp.
166196
SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
167197
ArrayRef<int64_t> shape);
168-
SmallVector<unsigned> getThreadOrder(RankedTensorType type);
198+
// Convenience overload for ranked tensors: forwards the type's encoding and
// shape to getThreadOrder(DistributedEncodingTrait, ArrayRef<int64_t>).
// The encoding is converted with an unchecked `cast`, so the tensor is
// expected to carry an encoding that implements DistributedEncodingTrait.
inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
199+
return getThreadOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
200+
type.getShape());
201+
}
169202

170203
CTALayoutAttr getCTALayout(Attribute layout);
171204

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -606,21 +606,6 @@ We call each individual tile "rep".
606606
"LinearLayout",
607607
"toLinearLayout",
608608
(ins "ArrayRef<int64_t>":$shape)>,
609-
610-
// Legacy methods: They do not take into account the shape of the tensor
611-
// that is, the fact that we use them to tile the tensor.
612-
InterfaceMethod<"Get the default order of the registers per warp. The fastest-changing axis first",
613-
"SmallVector<unsigned>",
614-
"getDefaultOrder">,
615-
616-
InterfaceMethod<"Get the default order of the threads per warp. The fastest-changing axis first",
617-
"SmallVector<unsigned>",
618-
"getDefaultThreadOrder">,
619-
620-
InterfaceMethod<"Get the default order of the warps per CTA. The fastest-changing axis first",
621-
"SmallVector<unsigned>",
622-
"getDefaultWarpOrder">
623-
624609
];
625610
}
626611

@@ -662,6 +647,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
662647
}];
663648

664649
code extraDistributedDeclaration = extraBaseClassDeclaration # [{
650+
// Rank of the layout, taken as the number of entries in getCTAOrder().
unsigned getRank() const { return getCTAOrder().size(); }
665651
// Implemented in subclasses
666652
SmallVector<unsigned> getRepOrder() const;
667653
SmallVector<unsigned> getCTAsPerCGA() const;
@@ -670,11 +656,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
670656
SmallVector<unsigned> getWarpsPerCTA() const;
671657

672658
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
673-
674-
// Legacy methods: They do not take into account the shape of the tensor
675-
SmallVector<unsigned> getDefaultWarpOrder() const;
676-
SmallVector<unsigned> getDefaultThreadOrder() const;
677-
SmallVector<unsigned> getDefaultOrder() const;
678659
}];
679660
}
680661

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
121121
Attribute dstLayout = dstTy.getEncoding();
122122

123123
assert(cvtNeedsSharedMemory(srcTy, dstTy));
124-
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
124+
auto outOrd = gpu::getOrder(dstTy);
125125
scratchConfig.order = outOrd;
126126

127127
std::tie(scratchConfig.inVec, scratchConfig.outVec) =

0 commit comments

Comments
 (0)