2424#include " gc/Utils/Log.h"
2525
2626using namespace mlir ;
27- // using namespace mlir::gc::gpu ;
27+ using namespace mlir ::gc;
2828
2929namespace mlir ::gc {
3030#define GEN_PASS_DECL_GPUTILINGANDFUSION
@@ -39,33 +39,34 @@ struct GpuTilingAndFusion final
3939 gc::impl::GpuTilingAndFusionBase<GpuTilingAndFusion> {
4040 friend struct GpuPass ;
4141 explicit GpuTilingAndFusion ()
42- : GpuTilingAndFusion(gc:: GpuTilingAndFusionOptions{}) {}
43- explicit GpuTilingAndFusion (const gc:: GpuTilingAndFusionOptions &opts)
42+ : GpuTilingAndFusion(GpuTilingAndFusionOptions{}) {}
43+ explicit GpuTilingAndFusion (const GpuTilingAndFusionOptions &opts)
4444 : GpuPass(), GpuTilingAndFusionBase(opts) {}
4545
4646 void runOnOperation () override {
4747 IRRewriter rewriter (&getContext ());
4848 scf::SCFTileAndFuseOptions opts;
49+ opts.setFusionControlFn (
50+ [&](tensor::ExtractSliceOp, OpResult originalProducer, bool )
51+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
52+ Operation *op = originalProducer.getOwner ();
53+ if (!op) {
54+ return std::nullopt ;
55+ }
56+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
57+ if (!linalgOp.hasOnlyProjectedPermutations ()) {
58+ return std::nullopt ;
59+ }
60+ }
61+ return scf::SCFTileAndFuseOptions::ControlFnResult{};
62+ });
4963 opts.tilingOptions .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
5064 // The outer loop is converted to a GPU kernel and the tile sizes are mapped
5165 // to the grid sizes.
5266 opts.tilingOptions .setTileSizeComputationFunction (
53- // The tile sizes calculation is based on the following equation:
54- // n * TS0 * TS1 * ... * TSn = euMem
55- // where:
56- // n - an average number of bytes, processed by each iteration
57- // TS0, TS1, ... TSn - the tile sizes for each loop correspondingly
58- // euMem - the physical memory (cache) size of the GPU execution unit
59- //
60- // To calculate the tile size TS, we need to divide the total loop size
61- // S by the ratio r:
62- //
63- // n * (S0/r0) * (S1/r1) * ... * (Sn/rn) = euMem
64- // r0 * r1 * ... * rn = (n * S0 * S1 * ... * Sn) / euMem
65- // If all sizes are equal, then S0 = ... = Sn = S, r0 = ... = rn = r:
66- // r^n = (n * S^n) / euMem
67- // r = (n * S^n / euMem)^(1/n)
68- [euMem = getEuMem (rewriter), euThreads = getEuThreads (rewriter)](
67+ [numThreads = getNumThreads (rewriter),
68+ cacheSize = getCacheSize (rewriter),
69+ vectorWidth = getVectorWidth (rewriter)](
6970 OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
7071 auto ti = dyn_cast<TilingInterface>(op);
7172 if (!ti) {
@@ -76,44 +77,47 @@ struct GpuTilingAndFusion final
7677 auto itDomains = ti.getIterationDomain (builder);
7778 assert (itTypes.size () == itDomains.size ());
7879
79- // TODO: Add a parameter to the options?
80- size_t totalSize = calcOperandsSize (op) * euThreads;
81- unsigned loopCount = 0 ;
82-
80+ SmallVector<int64_t > tiles;
8381 for (auto [t, r] : zip (itTypes, itDomains)) {
8482 if (t == utils::IteratorType::parallel) {
8583 if (auto v = getConstantIntValue (r.size )) {
86- loopCount++;
87- totalSize *= *v;
84+ tiles.emplace_back (*v);
8885 } else {
89- return calcDynamicSizes (builder, ti, euMem, euThreads );
86+ return calcDynamicSizes (builder, ti, cacheSize );
9087 }
9188 }
9289 }
9390
94- if (loopCount == 0 ) {
91+ if (tiles. empty () ) {
9592 return {};
9693 }
9794
98- // TODO: In case of different sizes, calculate the ratio for each loop
99- double ratio = std::pow (static_cast <double >(totalSize) /
100- static_cast <double >(euMem),
101- 1.0 / loopCount);
102- ratio = std::max (1.0 , ratio);
103- SmallVector<OpFoldResult> tiles;
104- tiles.reserve (itDomains.size ());
95+ int64_t elementSize = getElementSize (op);
96+ int64_t totalSize;
97+
98+ // If the operation could be lowered to XeGPU, make the tiles
99+ // proportional to the vector width. Otherwise, use the cache size.
100+ if (canLowerToXeGPU (op)) {
101+ totalSize = vectorWidth * numThreads / elementSize;
102+ } else {
103+ totalSize = cacheSize / elementSize;
104+ }
105+
106+ adjustTiles (totalSize, tiles);
107+
108+ unsigned counter = 0 ;
109+ SmallVector<OpFoldResult> result;
110+ result.reserve (itDomains.size ());
105111
106112 for (auto [t, r] : zip (itTypes, itDomains)) {
107113 if (t != utils::IteratorType::parallel) {
108- tiles.emplace_back (builder.getIndexAttr (1 ));
109- } else if (auto v = getConstantIntValue (r.size )) {
110- tiles.emplace_back (ceil (builder, *v, ratio));
114+ result.emplace_back (builder.getIndexAttr (1 ));
111115 } else {
112- abort (); // Must never get here
116+ result. emplace_back (builder. getIndexAttr (tiles[counter++]));
113117 }
114118 }
115119
116- return tiles ;
120+ return result ;
117121 });
118122
119123 auto fn = getOperation ();
@@ -174,7 +178,8 @@ struct GpuTilingAndFusion final
174178 static std::optional<TilingInterface> findTi (Operation *op) {
175179 std::optional<TilingInterface> last;
176180 op->walk <WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
177- if (!linalgOp->getParentOfType <scf::ForallOp>()) {
181+ if (linalgOp.hasOnlyProjectedPermutations () &&
182+ !linalgOp->getParentOfType <scf::ForallOp>()) {
178183 if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation ())) {
179184 last = ti;
180185 }
@@ -184,17 +189,16 @@ struct GpuTilingAndFusion final
184189 return last;
185190 }
186191
187- static SmallVector<OpFoldResult> calcDynamicSizes (OpBuilder &builder,
188- TilingInterface ti,
189- size_t euMem,
190- size_t euThreads) {
192+ // TODO: Use the adjustTiles() function from MLIR.
193+ static SmallVector<OpFoldResult>
194+ calcDynamicSizes (OpBuilder &builder, TilingInterface ti, int64_t cacheSize) {
191195 auto itTypes = ti.getLoopIteratorTypes ();
192196 auto itDomains = ti.getIterationDomain (builder);
193197 assert (itTypes.size () == itDomains.size ());
194198
195199 auto loc = ti.getLoc ();
196200 Value dynamicSize;
197- size_t staticSize = calcOperandsSize (ti.getOperation ()) * euThreads ;
201+ int64_t staticSize = getElementSize (ti.getOperation ());
198202 unsigned loopCount = 0 ;
199203
200204 for (auto [t, r] : zip (itTypes, itDomains)) {
@@ -225,7 +229,7 @@ struct GpuTilingAndFusion final
225229 dynamicSize));
226230
227231 auto memSize = builder.create <arith::ConstantFloatOp>(
228- loc, APFloat (static_cast <double >(euMem )), builder.getF64Type ());
232+ loc, APFloat (static_cast <double >(cacheSize )), builder.getF64Type ());
229233 auto pow = builder.create <arith::ConstantFloatOp>(
230234 loc, APFloat (1.0 / loopCount), builder.getF64Type ());
231235 Value ratio = builder.create <math::PowFOp>(
@@ -265,29 +269,63 @@ struct GpuTilingAndFusion final
265269 return tiles;
266270 }
267271
268- static size_t calcOperandsSize (Operation *op) {
269- size_t size = 0 ;
270- auto typeSize = [](Type t) -> size_t {
271- Type et;
272- if (auto mt = dyn_cast<MemRefType>(t)) {
273- et = mt.getElementType ();
274- } else if (auto tt = dyn_cast<TensorType>(t)) {
275- et = tt.getElementType ();
272+ static int64_t getElementSize (Operation *op) {
273+ int64_t elementSize = 1 ;
274+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
275+ if (auto inits = linalgOp.getDpsInits (); !inits.empty ()) {
276+ if (auto t = getElementTypeOrSelf (inits[0 ].getType ());
277+ t.isIntOrFloat ()) {
278+ elementSize = t.getIntOrFloatBitWidth () / 8 ;
279+ }
280+ }
281+ }
282+ return elementSize;
283+ }
284+
285+ // TODO: Add more checks
286+ static bool canLowerToXeGPU (Operation *operation) {
287+ auto op = dyn_cast<linalg::LinalgOp>(operation);
288+ if (!op) {
289+ return false ;
290+ }
291+ if (op.hasDynamicShape ()) {
292+ return false ;
293+ }
294+
295+ auto checkOperand = [&](Value operand, bool isOutput = false ) {
296+ ShapedType type;
297+ if (auto memref = dyn_cast<MemRefType>(operand.getType ())) {
298+ type = memref;
299+ } else if (auto tensor = dyn_cast<RankedTensorType>(operand.getType ())) {
300+ type = tensor;
276301 } else {
277- return 0 ;
302+ return false ;
278303 }
279- return et.isIntOrFloat () ? et.getIntOrFloatBitWidth () / 8 : 1 ;
280- };
281- for (auto operand : op->getOperands ()) {
282- if (auto defOp = operand.getDefiningOp ()) {
283- for (auto t : defOp->getResultTypes ()) {
284- size += typeSize (t);
304+
305+ auto shape = type.getShape ();
306+ if (isOutput) {
307+ if (shape.size () != 2 || shape[0 ] * shape[1 ] < 16 ) {
308+ return false ;
285309 }
286- } else {
287- size += typeSize (operand. getType ()) ;
310+ } else if (shape. size () > 2 ) {
311+ return false ;
288312 }
313+
314+ return true ;
315+ };
316+
317+ if (auto inits = op.getDpsInits ();
318+ !inits.empty () && !checkOperand (inits[0 ], true )) {
319+ return false ;
289320 }
290- return size == 0 ? 1 : size;
321+
322+ if (auto inputs = op.getDpsInputs ();
323+ !std::all_of (inputs.begin (), inputs.end (),
324+ [&](Value v) { return checkOperand (v); })) {
325+ return false ;
326+ }
327+
328+ return true ;
291329 }
292330};
293331} // namespace
0 commit comments