@@ -46,6 +46,21 @@ struct GpuTilingAndFusion final
4646 void runOnOperation () override {
4747 IRRewriter rewriter (&getContext ());
4848 scf::SCFTileAndFuseOptions opts;
49+ opts.setFusionControlFn (
50+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
51+ bool isDestinationOperand)
52+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
53+ Operation *op = originalProducer.getOwner ();
54+ if (!op) {
55+ return std::nullopt ;
56+ }
57+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
58+ if (!linalgOp.hasOnlyProjectedPermutations ()) {
59+ return std::nullopt ;
60+ }
61+ }
62+ return scf::SCFTileAndFuseOptions::ControlFnResult{};
63+ });
4964 opts.tilingOptions .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
5065 // The outer loop is converted to a GPU kernel and the tile sizes are mapped
5166 // to the grid sizes.
@@ -77,13 +92,15 @@ struct GpuTilingAndFusion final
7792 assert (itTypes.size () == itDomains.size ());
7893
7994 // TODO: Add a parameter to the options?
80- size_t totalSize = calcOperandsSize (op) * euThreads ;
95+ size_t totalSize = calcOperandsSize (op);
8196 unsigned loopCount = 0 ;
97+ SmallVector<int64_t > sizes;
8298
8399 for (auto [t, r] : zip (itTypes, itDomains)) {
84100 if (t == utils::IteratorType::parallel) {
85101 if (auto v = getConstantIntValue (r.size )) {
86102 loopCount++;
103+ sizes.emplace_back (*v);
87104 totalSize *= *v;
88105 } else {
89106 return calcDynamicSizes (builder, ti, euMem, euThreads);
@@ -95,19 +112,25 @@ struct GpuTilingAndFusion final
95112 return {};
96113 }
97114
98- // TODO: In case of different sizes, calculate the ratio for each loop
99- double ratio = std::pow (static_cast <double >(totalSize) /
100- static_cast <double >(euMem),
101- 1.0 / loopCount);
102- ratio = std::max (1.0 , ratio);
115+ auto outerTileSize = static_cast <size_t >(
116+ std::ceil (static_cast <double >(euMem) /
117+ static_cast <double >(calcOperandsSize (op))));
118+ SmallVector<int64_t > outerTiles;
119+ SmallVector<int64_t > innerTiles;
120+ normaliseTiles (outerTileSize, sizes, outerTiles);
121+ normaliseTiles (euThreads, sizes, innerTiles);
122+
123+ unsigned counter = 0 ;
103124 SmallVector<OpFoldResult> tiles;
104125 tiles.reserve (itDomains.size ());
105126
106127 for (auto [t, r] : zip (itTypes, itDomains)) {
107128 if (t != utils::IteratorType::parallel) {
108129 tiles.emplace_back (builder.getIndexAttr (1 ));
109130 } else if (auto v = getConstantIntValue (r.size )) {
110- tiles.emplace_back (ceil (builder, *v, ratio));
131+ tiles.emplace_back (
132+ ceil (builder, outerTiles[counter], innerTiles[counter]));
133+ counter++;
111134 } else {
112135 abort (); // Must never get here
113136 }
@@ -174,7 +197,8 @@ struct GpuTilingAndFusion final
174197 static std::optional<TilingInterface> findTi (Operation *op) {
175198 std::optional<TilingInterface> last;
176199 op->walk <WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
177- if (!linalgOp->getParentOfType <scf::ForallOp>()) {
200+ if (linalgOp.hasOnlyProjectedPermutations () &&
201+ !linalgOp->getParentOfType <scf::ForallOp>()) {
178202 if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation ())) {
179203 last = ti;
180204 }
0 commit comments