@@ -46,26 +46,27 @@ struct GpuTilingAndFusion final
4646 void runOnOperation () override {
4747 IRRewriter rewriter (&getContext ());
4848 scf::SCFTileAndFuseOptions opts;
49+ opts.setFusionControlFn (
50+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
51+ bool isDestinationOperand)
52+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
53+ Operation *op = originalProducer.getOwner ();
54+ if (!op) {
55+ return std::nullopt ;
56+ }
57+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
58+ if (!linalgOp.hasOnlyProjectedPermutations ()) {
59+ return std::nullopt ;
60+ }
61+ }
62+ return scf::SCFTileAndFuseOptions::ControlFnResult{};
63+ });
4964 opts.tilingOptions .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
5065 // The outer loop is converted to a GPU kernel and the tile sizes are mapped
5166 // to the grid sizes.
5267 opts.tilingOptions .setTileSizeComputationFunction (
53- // The tile sizes calculation is based on the following equation:
54- // n * TS0 * TS1 * ... * TSn = euMem
55- // where:
56- // n - an average number of bytes, processed by each iteration
57- // TS0, TS1, ... TSn - the tile sizes for each loop correspondingly
58- // euMem - the physical memory (cache) size of the GPU execution unit
59- //
60- // To calculate the tile size TS, we need to divide the total loop size
61- // S by the ratio r:
62- //
63- // n * (S0/r0) * (S1/r1) * ... * (Sn/rn) = euMem
64- // r0 * r1 * ... * rn = (n * S0 * S1 * ... * Sn) / euMem
65- // If all sizes are equal, then S0 = ... = Sn = S, r0 = ... = rn = r:
66- // r^n = (n * S^n) / euMem
67- // r = (n * S^n / euMem)^(1/n)
68- [euMem = getEuMem (rewriter), euThreads = getEuThreads (rewriter)](
68+ [euMem = getEuMem (rewriter), euThreads = getEuThreads (rewriter),
69+ vectorWidth = getVectorWidth (rewriter)](
6970 OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
7071 auto ti = dyn_cast<TilingInterface>(op);
7172 if (!ti) {
@@ -76,44 +77,45 @@ struct GpuTilingAndFusion final
7677 auto itDomains = ti.getIterationDomain (builder);
7778 assert (itTypes.size () == itDomains.size ());
7879
79- // TODO: Add a parameter to the options?
80- size_t totalSize = calcOperandsSize (op) * euThreads;
81- unsigned loopCount = 0 ;
82-
80+ SmallVector<int64_t > tiles;
8381 for (auto [t, r] : zip (itTypes, itDomains)) {
8482 if (t == utils::IteratorType::parallel) {
8583 if (auto v = getConstantIntValue (r.size )) {
86- loopCount++;
87- totalSize *= *v;
84+ tiles.emplace_back (*v);
8885 } else {
8986 return calcDynamicSizes (builder, ti, euMem, euThreads);
9087 }
9188 }
9289 }
9390
94- if (loopCount == 0 ) {
91+ if (tiles. empty () ) {
9592 return {};
9693 }
9794
98- // TODO: In case of different sizes, calculate the ratio for each loop
99- double ratio = std::pow (static_cast <double >(totalSize) /
100- static_cast <double >(euMem),
101- 1.0 / loopCount);
102- ratio = std::max (1.0 , ratio);
103- SmallVector<OpFoldResult> tiles;
104- tiles.reserve (itDomains.size ());
95+ size_t elementSize = 1 ;
96+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
97+ auto t = linalgOp.getDpsInits ()[0 ].getType ();
98+ if (t.isIntOrFloat ()) {
99+ elementSize = t.getIntOrFloatBitWidth () / 8 ;
100+ }
101+ }
102+ calcTiles (
103+ std::max (euThreads, euThreads / 2 * vectorWidth / elementSize),
104+ tiles);
105+
106+ unsigned counter = 0 ;
107+ SmallVector<OpFoldResult> result;
108+ result.reserve (itDomains.size ());
105109
106110 for (auto [t, r] : zip (itTypes, itDomains)) {
107111 if (t != utils::IteratorType::parallel) {
108- tiles.emplace_back (builder.getIndexAttr (1 ));
109- } else if (auto v = getConstantIntValue (r.size )) {
110- tiles.emplace_back (ceil (builder, *v, ratio));
112+ result.emplace_back (builder.getIndexAttr (1 ));
111113 } else {
112- abort (); // Must never get here
114+ result. emplace_back (builder. getIndexAttr (tiles[counter++]));
113115 }
114116 }
115117
116- return tiles ;
118+ return result ;
117119 });
118120
119121 auto fn = getOperation ();
@@ -174,7 +176,8 @@ struct GpuTilingAndFusion final
174176 static std::optional<TilingInterface> findTi (Operation *op) {
175177 std::optional<TilingInterface> last;
176178 op->walk <WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
177- if (!linalgOp->getParentOfType <scf::ForallOp>()) {
179+ if (linalgOp.hasOnlyProjectedPermutations () &&
180+ !linalgOp->getParentOfType <scf::ForallOp>()) {
178181 if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation ())) {
179182 last = ti;
180183 }
0 commit comments