@@ -196,55 +196,6 @@ LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName,
196196 return layout.compose (transform);
197197}
198198
199- // Combines the layout of a CTA (input dims [register, lane, warp]) with the
200- // layout of a CGA (i.e. a block), and ensures that the resulting layout has the
201- // given shape.
202- //
203- // See the nomenclature note at the top of the file for why the variable with
204- // type CTALayoutAttr is called cgaLayoutAttr.
205- LinearLayout combineCtaCgaWithShape (LinearLayout ctaLayout,
206- CTALayoutAttr cgaLayoutAttr,
207- ArrayRef<int64_t > shape) {
208- int rank = shape.size ();
209- assert (ctaLayout.getNumOutDims () == rank &&
210- " ctaLayout must have the same rank as shape" );
211- assert (cgaLayoutAttr.getCTAOrder ().size () == rank &&
212- " cgaLayoutAttr must have the same rank as shape" );
213- MLIRContext *ctx = cgaLayoutAttr.getContext ();
214-
215- SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
216-
217- llvm::SmallDenseMap<StringAttr, int64_t > labeledShape;
218- for (auto [dim, size] : llvm::zip (outDimNames, shape)) {
219- labeledShape[dim] = size;
220- }
221-
222- LinearLayout cgaLayout =
223- ensureLayoutNotLargerThan (makeCgaLayout (cgaLayoutAttr), labeledShape)
224- .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
225-
226- // Calculate the shape of the ctaLayout, which is `shape` divided by the
227- // cgaLayout's size.
228- llvm::SmallDenseMap<StringAttr, int64_t > ctaShape;
229- assert (llvm::to_vector (ctaLayout.getOutDimNames ()) ==
230- llvm::to_vector (cgaLayout.getOutDimNames ()) &&
231- " bad layout" );
232- for (auto dim : ctaLayout.getOutDimNames ()) {
233- ctaShape[dim] =
234- std::max (int64_t {1 }, labeledShape[dim] / cgaLayout.getOutDimSize (dim));
235- }
236-
237- ctaLayout = ensureLayoutNotSmallerThan (ctaLayout, ctaShape);
238- ctaLayout = ensureLayoutNotLargerThan (ctaLayout, ctaShape);
239-
240- LinearLayout ret =
241- (std::move (ctaLayout) * std::move (cgaLayout)).transposeOuts (outDimNames);
242- for (auto dim : ret.getOutDimNames ()) {
243- assert (ret.getOutDimSize (dim) == labeledShape[dim] && " bad shape" );
244- }
245- return ret;
246- }
247-
248199} // anonymous namespace
249200
250201// clang-format off
0 commit comments