@@ -1362,16 +1362,37 @@ static func::FuncOp createGPUWrapper(ModuleOp module,
13621362 func::FuncOp::create (loc, StringRef (funcNameGpu), gpuWrapperFuncType);
13631363 module .push_back (gpuWrapperFunc);
13641364
1365+ // Emit device selection
1366+ if (deviceNum.getNumOccurrences () > 0 ) {
1367+ const int32_t priority = 122 ;
1368+ const StringRef constructorName = " setDeviceCtor" ;
1369+ auto func = module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(constructorName);
1370+ if (!func) {
1371+ func = b.create <mlir::LLVM::LLVMFuncOp>(
1372+ module .getLoc (), constructorName,
1373+ mlir::LLVM::LLVMFunctionType::get (LLVM::LLVMVoidType::get (context),
1374+ {}));
1375+ module .push_back (func);
1376+
1377+ Block *block = func.addEntryBlock (b);
1378+ b.setInsertionPoint (block, block->begin ());
1379+ b.create <gpu::SetDefaultDeviceOp>(
1380+ loc, b.create <arith::ConstantIntOp>(loc, b.getIntegerType (32 ),
1381+ deviceNum.getValue ()));
1382+ b.create <mlir::LLVM::ReturnOp>(loc, ValueRange{});
1383+
1384+ b.setInsertionPointToEnd (module .getBody ());
1385+ b.create <mlir::LLVM::GlobalCtorsOp>(
1386+ loc, b.getArrayAttr (mlir::SymbolRefAttr::get (func)),
1387+ b.getI32ArrayAttr ({priority}),
1388+ b.getArrayAttr (mlir::LLVM::ZeroAttr::get (context)));
1389+ }
1390+ }
1391+
13651392 // Emit gpu convolution logic.
13661393 Block *block = gpuWrapperFunc.addEntryBlock ();
13671394 b.setInsertionPoint (block, block->begin ());
13681395
1369- // Emit device selection
1370- if (deviceNum.getNumOccurrences () > 0 )
1371- b.create <gpu::SetDefaultDeviceOp>(
1372- loc, b.create <arith::ConstantIntOp>(loc, b.getIntegerType (32 ),
1373- deviceNum.getValue ()));
1374-
13751396 SmallVector<Value, 4 > cpuMem;
13761397 SmallVector<Value, 4 > gpuMem;
13771398 for (auto pair : llvm::enumerate (kernels[0 ].params )) {
@@ -5006,7 +5027,8 @@ int main(int argc, char **argv) {
50065027 math::MathDialect, arith::ArithDialect,
50075028 vector::VectorDialect, gpu::GPUDialect,
50085029 linalg::LinalgDialect, mhal::MHALDialect,
5009- bufferization::BufferizationDialect, tosa::TosaDialect>();
5030+ bufferization::BufferizationDialect, tosa::TosaDialect,
5031+ mlir::LLVM::LLVMDialect>();
50105032
50115033 // Parse pass names in main to ensure static initialization completed.
50125034 llvm::cl::ParseCommandLineOptions (argc, argv,
0 commit comments