@@ -175,6 +175,7 @@ class LaunchKernel {
175175 IRBuilderBase &builder;
176176 mlir::LLVM::ModuleTranslation &moduleTranslation;
177177 Type *i32Ty{};
178+ Type *i64Ty{};
178179 Type *voidTy{};
179180 Type *intPtrTy{};
180181 PointerType *ptrTy{};
@@ -216,6 +217,7 @@ llvm::LaunchKernel::LaunchKernel(
216217 mlir::LLVM::ModuleTranslation &moduleTranslation)
217218 : module(module ), builder(builder), moduleTranslation(moduleTranslation) {
218219 i32Ty = builder.getInt32Ty ();
220+ i64Ty = builder.getInt64Ty ();
219221 ptrTy = builder.getPtrTy (0 );
220222 voidTy = builder.getVoidTy ();
221223 intPtrTy = builder.getIntPtrTy (module .getDataLayout ());
@@ -224,11 +226,11 @@ llvm::LaunchKernel::LaunchKernel(
224226llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn () {
225227 return module .getOrInsertFunction (
226228 " mgpuLaunchKernel" ,
227- FunctionType::get (
228- voidTy ,
229- ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy ,
230- intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy}),
231- false ));
229+ FunctionType::get (voidTy,
230+ ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy ,
231+ intPtrTy, intPtrTy, intPtrTy, i32Ty ,
232+ ptrTy, ptrTy, ptrTy, i64Ty }),
233+ false ));
232234}
233235
234236llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn () {
@@ -251,7 +253,7 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
251253llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn () {
252254 return module .getOrInsertFunction (
253255 " mgpuModuleLoad" ,
254- FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy}), false ));
256+ FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, i64Ty }), false ));
255257}
256258
257259llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn () {
@@ -391,10 +393,24 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
391393 if (!binary)
392394 return op.emitError () << " Couldn't find the binary: " << binaryIdentifier;
393395
396+ auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
397+ if (!binaryVar)
398+ return op.emitError () << " Binary is not a global variable: "
399+ << binaryIdentifier;
400+ llvm::Constant *binaryInit = binaryVar->getInitializer ();
401+ auto binaryDataSeq =
402+ dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
403+ if (!binaryDataSeq)
404+ return op.emitError () << " Couldn't find binary data array: "
405+ << binaryIdentifier;
406+ llvm::Constant *binarySize =
407+ llvm::ConstantInt::get (i64Ty, binaryDataSeq->getNumElements () *
408+ binaryDataSeq->getElementByteSize ());
409+
394410 Value *moduleObject =
395411 object.getFormat () == gpu::CompilationTarget::Assembly
396412 ? builder.CreateCall (getModuleLoadJITFn (), {binary, optV})
397- : builder.CreateCall (getModuleLoadFn (), {binary});
413+ : builder.CreateCall (getModuleLoadFn (), {binary, binarySize });
398414
399415 // Load the kernel function.
400416 Value *moduleFunction = builder.CreateCall (
@@ -413,6 +429,9 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
413429 stream = builder.CreateCall (getStreamCreateFn (), {});
414430 }
415431
432+ llvm::Constant *paramsCount =
433+ llvm::ConstantInt::get (i64Ty, op.getNumKernelOperands ());
434+
416435 // Create the launch call.
417436 Value *nullPtr = ConstantPointerNull::get (ptrTy);
418437
@@ -426,10 +445,10 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
426445 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
427446 dynamicMemorySize, stream, argArray, nullPtr}));
428447 } else {
429- builder.CreateCall (
430- getKernelLaunchFn () ,
431- ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz ,
432- dynamicMemorySize, stream, argArray, nullPtr}));
448+ builder.CreateCall (getKernelLaunchFn (),
449+ ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by ,
450+ bz, dynamicMemorySize, stream ,
451+ argArray, nullPtr, paramsCount }));
433452 }
434453
435454 // Sync & destroy the stream, for synchronous launches.
0 commit comments