@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
134134 mod.walk ([&](mlir::Operation *op) {
135135 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
136136 if (!hasPortableSignature (call.getFunctionType (), op))
137- convertCallOp (call);
137+ convertCallOp (call, call. getFunctionType () );
138138 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
139139 if (!hasPortableSignature (dispatch.getFunctionType (), op))
140- convertCallOp (dispatch);
140+ convertCallOp (dispatch, dispatch.getFunctionType ());
141+ } else if (auto gpuLaunchFunc =
142+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
143+ llvm::SmallVector<mlir::Type> operandsTypes;
144+ for (auto arg : gpuLaunchFunc.getKernelOperands ())
145+ operandsTypes.push_back (arg.getType ());
146+ auto fctTy = mlir::FunctionType::get (&context, operandsTypes, {});
147+ if (!hasPortableSignature (fctTy, op))
148+ convertCallOp (gpuLaunchFunc, fctTy);
141149 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
142150 if (mlir::isa<mlir::FunctionType>(addr.getType ()) &&
143151 !hasPortableSignature (addr.getType (), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
357365
358366 // Convert fir.call and fir.dispatch Ops.
359367 template <typename A>
360- void convertCallOp (A callOp) {
361- auto fnTy = callOp.getFunctionType ();
368+ void convertCallOp (A callOp, mlir::FunctionType fnTy) {
362369 auto loc = callOp.getLoc ();
363370 rewriter->setInsertionPoint (callOp);
364371 llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
376383 newOpers.push_back (callOp.getOperand (0 ));
377384 dropFront = 1 ;
378385 }
379- } else {
386+ } else if constexpr (std::is_same_v<std:: decay_t <A>, fir::DispatchOp>) {
380387 dropFront = 1 ; // First operand is the polymorphic object.
381388 }
382389
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
402409
403410 llvm::SmallVector<mlir::Type> trailingInTys;
404411 llvm::SmallVector<mlir::Value> trailingOpers;
412+ llvm::SmallVector<mlir::Value> operands;
405413 unsigned passArgShift = 0 ;
414+ if constexpr (std::is_same_v<std::decay_t <A>, mlir::gpu::LaunchFuncOp>)
415+ operands = callOp.getKernelOperands ();
416+ else
417+ operands = callOp.getOperands ().drop_front (dropFront);
406418 for (auto e : llvm::enumerate (
407- llvm::zip (fnTy.getInputs ().drop_front (dropFront),
408- callOp.getOperands ().drop_front (dropFront)))) {
419+ llvm::zip (fnTy.getInputs ().drop_front (dropFront), operands))) {
409420 mlir::Type ty = std::get<0 >(e.value ());
410421 mlir::Value oper = std::get<1 >(e.value ());
411422 unsigned index = e.index ();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
507518 newOpers.insert (newOpers.end (), trailingOpers.begin (), trailingOpers.end ());
508519
509520 llvm::SmallVector<mlir::Value, 1 > newCallResults;
510- if constexpr (std::is_same_v<std::decay_t <A>, fir::CallOp>) {
521+ if constexpr (std::is_same_v<std::decay_t <A>, mlir::gpu::LaunchFuncOp>) {
522+ auto newCall = rewriter->create <A>(
523+ loc, callOp.getKernel (), callOp.getGridSizeOperandValues (),
524+ callOp.getBlockSizeOperandValues (),
525+ callOp.getDynamicSharedMemorySize (), newOpers);
526+ if (callOp.getClusterSizeX ())
527+ newCall.getClusterSizeXMutable ().assign (callOp.getClusterSizeX ());
528+ if (callOp.getClusterSizeY ())
529+ newCall.getClusterSizeYMutable ().assign (callOp.getClusterSizeY ());
530+ if (callOp.getClusterSizeZ ())
531+ newCall.getClusterSizeZMutable ().assign (callOp.getClusterSizeZ ());
532+ newCallResults.append (newCall.result_begin (), newCall.result_end ());
533+ } else if constexpr (std::is_same_v<std::decay_t <A>, fir::CallOp>) {
511534 fir::CallOp newCall;
512535 if (callOp.getCallee ()) {
513536 newCall =
0 commit comments