Skip to content

Commit f1140b7

Browse files
author
Mei, Yijie
committed
Rebase patch: Modify insert-gpu-allocs pass to also consider xegpu.store/load
1 parent ba6721e commit f1140b7

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

lib/Transforms/InsertGPUAllocs.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <mlir/Dialect/Func/IR/FuncOps.h>
2929
#include <mlir/Dialect/GPU/Transforms/Passes.h>
3030
#include <mlir/Dialect/MemRef/IR/MemRef.h>
31+
#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
3132
#include <mlir/Dialect/SCF/IR/SCF.h>
3233
#include <mlir/Pass/Pass.h>
3334

@@ -158,6 +159,9 @@ class InsertGPUAllocsPass final
158159
} else if (auto init_tile =
159160
mlir::dyn_cast<imex::xetile::InitTileOp>(op)) {
160161
return {{init_tile.getSource()}};
162+
} else if (auto init_xedesc =
163+
mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(op)) {
164+
return {{init_xedesc.getSource()}};
161165
} else {
162166
op->emitError("Uhhandled mem op in gpu region");
163167
return std::nullopt;
@@ -187,6 +191,9 @@ class InsertGPUAllocsPass final
187191
// Only handle the case where the tile source is a memref
188192
return init_tile.isSourceMemRef();
189193
}
194+
if (auto init_xedesc = mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(op)) {
195+
return true;
196+
}
190197
return false;
191198
};
192199

@@ -259,6 +266,36 @@ class InsertGPUAllocsPass final
259266
return;
260267
}
261268

269+
// walk over the users and find xegpu.load/store ops
270+
std::function<void(mlir::Operation*, bool, AccessType&)> findXeGPULoadStore;
271+
findXeGPULoadStore = [&](mlir::Operation *use, bool onDevice, AccessType& ret) {
272+
if (auto tile_update = mlir::dyn_cast<mlir::xegpu::UpdateNdOffsetOp>(use)) {
273+
auto res = tile_update->getResult(0);
274+
for (auto u : res.getUsers()) {
275+
findXeGPULoadStore(u, onDevice, ret);
276+
}
277+
}
278+
if (auto tile_for = mlir::dyn_cast<::mlir::scf::ForOp>(use)) {
279+
for (size_t idx=0; idx<tile_for.getInits().size(); idx++) {
280+
auto a = tile_for.getRegionIterArg(idx);
281+
for (auto u : a.getUsers()) {
282+
findXeGPULoadStore(u, onDevice, ret);
283+
}
284+
}
285+
}
286+
if (auto tile_load =
287+
mlir::dyn_cast<mlir::xegpu::LoadNdOp>(use)) {
288+
(onDevice ? ret.deviceRead : ret.hostRead) = true;
289+
}
290+
else if (auto tile_prefetch =
291+
mlir::dyn_cast<mlir::xegpu::PrefetchNdOp>(use)) {
292+
(onDevice ? ret.deviceRead : ret.hostRead) = true;
293+
} else if (auto tile_store =
294+
mlir::dyn_cast<mlir::xegpu::StoreNdOp>(use)) {
295+
(onDevice ? ret.deviceWrite : ret.hostWrite) = true;
296+
}
297+
};
298+
262299
// Checks the access type of the OP under consideration.
263300
auto getAccessType = [&](mlir::Value memref) {
264301
AccessType ret;
@@ -298,6 +335,15 @@ class InsertGPUAllocsPass final
298335
continue;
299336
}
300337

338+
if (auto init_xedesc = mlir::dyn_cast<mlir::xegpu::CreateNdDescOp>(user)) {
339+
bool onDevice = user->getParentOfType<mlir::gpu::LaunchOp>();
340+
auto res = init_xedesc->getResult(0);
341+
for (auto use : res.getUsers()) {
342+
findXeGPULoadStore(use, onDevice, ret);
343+
}
344+
continue;
345+
}
346+
301347
if (mlir::isa<mlir::func::ReturnOp>(user)) {
302348
ret.hostRead = true;
303349
ret.hostWrite = true;

0 commit comments

Comments
 (0)