Skip to content

Commit 6b99232

Browse files
authored
Fix rocmlir-gen device selection (#1964)
1 parent 97a0085 commit 6b99232

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,16 +1362,37 @@ static func::FuncOp createGPUWrapper(ModuleOp module,
13621362
func::FuncOp::create(loc, StringRef(funcNameGpu), gpuWrapperFuncType);
13631363
module.push_back(gpuWrapperFunc);
13641364

1365+
// Emit device selection
1366+
if (deviceNum.getNumOccurrences() > 0) {
1367+
const int32_t priority = 122;
1368+
const StringRef constructorName = "setDeviceCtor";
1369+
auto func = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(constructorName);
1370+
if (!func) {
1371+
func = b.create<mlir::LLVM::LLVMFuncOp>(
1372+
module.getLoc(), constructorName,
1373+
mlir::LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
1374+
{}));
1375+
module.push_back(func);
1376+
1377+
Block *block = func.addEntryBlock(b);
1378+
b.setInsertionPoint(block, block->begin());
1379+
b.create<gpu::SetDefaultDeviceOp>(
1380+
loc, b.create<arith::ConstantIntOp>(loc, b.getIntegerType(32),
1381+
deviceNum.getValue()));
1382+
b.create<mlir::LLVM::ReturnOp>(loc, ValueRange{});
1383+
1384+
b.setInsertionPointToEnd(module.getBody());
1385+
b.create<mlir::LLVM::GlobalCtorsOp>(
1386+
loc, b.getArrayAttr(mlir::SymbolRefAttr::get(func)),
1387+
b.getI32ArrayAttr({priority}),
1388+
b.getArrayAttr(mlir::LLVM::ZeroAttr::get(context)));
1389+
}
1390+
}
1391+
13651392
// Emit gpu convolution logic.
13661393
Block *block = gpuWrapperFunc.addEntryBlock();
13671394
b.setInsertionPoint(block, block->begin());
13681395

1369-
// Emit device selection
1370-
if (deviceNum.getNumOccurrences() > 0)
1371-
b.create<gpu::SetDefaultDeviceOp>(
1372-
loc, b.create<arith::ConstantIntOp>(loc, b.getIntegerType(32),
1373-
deviceNum.getValue()));
1374-
13751396
SmallVector<Value, 4> cpuMem;
13761397
SmallVector<Value, 4> gpuMem;
13771398
for (auto pair : llvm::enumerate(kernels[0].params)) {
@@ -5006,7 +5027,8 @@ int main(int argc, char **argv) {
50065027
math::MathDialect, arith::ArithDialect,
50075028
vector::VectorDialect, gpu::GPUDialect,
50085029
linalg::LinalgDialect, mhal::MHALDialect,
5009-
bufferization::BufferizationDialect, tosa::TosaDialect>();
5030+
bufferization::BufferizationDialect, tosa::TosaDialect,
5031+
mlir::LLVM::LLVMDialect>();
50105032

50115033
// Parse pass names in main to ensure static initialization completed.
50125034
llvm::cl::ParseCommandLineOptions(argc, argv,

0 commit comments

Comments
 (0)