@@ -137,17 +137,44 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
 // is contiguous in shared memory.
 SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
                                ArrayRef<int64_t> shape);
-SmallVector<unsigned> getOrder(RankedTensorType type);
+inline SmallVector<unsigned> getOrder(RankedTensorType type) {
+  return getOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
+                  type.getShape());
+}
 
 SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
                                ArrayRef<int64_t> shape);
-SmallVector<unsigned> getOrder(MemDescType type);
-SmallVector<unsigned> getOrder(TensorOrMemDesc type);
+inline SmallVector<unsigned> getOrder(MemDescType type) {
+  return getOrder(cast<SharedEncodingTrait>(type.getEncoding()),
+                  type.getShape());
+}
+inline SmallVector<unsigned> getOrder(TensorOrMemDesc type) {
+  if (auto memDesc = dyn_cast<MemDescType>(type)) {
+    return getOrder(memDesc);
+  } else {
+    auto tensorTy = cast<RankedTensorType>(type);
+    return getOrder(tensorTy);
+  }
+}
 
-// Order of the elements in the shared memory as defined at layout creation
-// If this layout is associated with a MemDesc with a different shape
-// it may return a different order than the actual order of the elements
-SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
+// To be removed once we implement arbitrary swizzled layouts.
+// Heuristically chooses an order for the memory layout in which to store a
+// distributed layout, taking into account the order of its elements and
+// threads.
+SmallVector<unsigned> getOrderForMemory(DistributedEncodingTrait layout,
+                                        ArrayRef<int64_t> shape);
+inline SmallVector<unsigned> getOrderForMemory(RankedTensorType type) {
+  return getOrderForMemory(cast<DistributedEncodingTrait>(type.getEncoding()),
+                           type.getShape());
+}
+inline SmallVector<unsigned> getOrderForMemory(TensorOrMemDesc type) {
+  if (auto memDesc = dyn_cast<MemDescType>(type)) {
+    return getOrder(memDesc);
+  } else {
+    auto tensorTy = cast<RankedTensorType>(type);
+    return getOrderForMemory(tensorTy);
+  }
+}
 
 // Returns the dimensions along which warpId's are distributed.
 // warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
@@ -158,14 +185,20 @@ SmallVector<unsigned> getDefaultOrder(SharedEncodingTrait layout);
 // [warp1 warp3 warp5 warp7]
 SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
                                    ArrayRef<int64_t> shape);
-SmallVector<unsigned> getWarpOrder(RankedTensorType type);
+inline SmallVector<unsigned> getWarpOrder(RankedTensorType type) {
+  return getWarpOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
+                      type.getShape());
+}
 
 // Returns the dimensions along which threadId's are distributed.
 // Similar to warpOrder, threadOrder is necessary to tell the specific thread
 // distribution in the warp.
 SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
                                      ArrayRef<int64_t> shape);
-SmallVector<unsigned> getThreadOrder(RankedTensorType type);
+inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
+  return getThreadOrder(cast<DistributedEncodingTrait>(type.getEncoding()),
+                        type.getShape());
+}
 
 CTALayoutAttr getCTALayout(Attribute layout);
 
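
For context, a minimal usage sketch (not part of the diff) of the inline overloads added above. It assumes these helpers are declared in `triton/Dialect/TritonGPU/IR/Dialect.h` and that `tensorTy` and `memTy` carry encodings implementing `DistributedEncodingTrait` and `SharedEncodingTrait` respectively; otherwise the `cast<>`s inside the overloads will assert.

```cpp
// Sketch only: exercising the new getOrder/getOrderForMemory overloads.
// The header path and the encoding assumptions are ours, not the PR's.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

void dumpOrders(RankedTensorType tensorTy, MemDescType memTy) {
  // Dimension order of the distributed layout's elements in registers.
  SmallVector<unsigned> elemOrder = getOrder(tensorTy);
  // Dimension order of the shared-memory layout behind the memdesc.
  SmallVector<unsigned> smemOrder = getOrder(memTy);
  // Heuristic order for storing the distributed layout to shared memory.
  SmallVector<unsigned> memOrder = getOrderForMemory(tensorTy);
  // Orders in which warpId / threadId vary across the tensor's dimensions.
  SmallVector<unsigned> warpOrder = getWarpOrder(tensorTy);
  SmallVector<unsigned> threadOrder = getThreadOrder(tensorTy);
}
```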
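And a small worked reading of the `warpOrder` semantics from the comment in the second hunk, under our assumption that `order[0]` names the fastest-varying dimension:

```cpp
// For warpsPerCTA = [2, 4]:
//   warpOrder = [0, 1] (warpId varies fastest along dim 0):
//     [warp0 warp2 warp4 warp6]
//     [warp1 warp3 warp5 warp7]
//   warpOrder = [1, 0] (warpId varies fastest along dim 1):
//     [warp0 warp1 warp2 warp3]
//     [warp4 warp5 warp6 warp7]
```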