Skip to content

Commit 1fb495e

Browse files
authored
Fix kernel generation when kernelRepeats is greater than 1 (#1799)
Fixes issue described in ROCm/rocMLIR-internal#1803
1 parent b5afd2a commit 1fb495e

File tree

5 files changed

+100
-69
lines changed

5 files changed

+100
-69
lines changed

mlir/test/rocmlir-driver/populate_host.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@
5151
// CHECK-NEXT: vector.extractelement
5252
// CHECK-NEXT: memref.store %{{.*}}, %[[output]][%[[io]]] : memref<[[NGKHOWO]]x[[OTYPE]]>
5353
// CHECK-NEXT: }
54-
// CHECK-NEXT: call @rock_conv_gkc01_ngc01_ngk01_0_gpu({{.*}}, {{.*}}, {{.*}}) : (memref<[[GKCYX]]x[[TYPE]]>, memref<[[NGCHIWI]]x[[TYPE]]>, memref<[[NGKHOWO]]x[[OTYPE]]>) -> ()
54+
// CHECK-NEXT: call @rock_conv_gkc01_ngc01_ngk01_gpu({{.*}}, {{.*}}, {{.*}}) : (memref<[[GKCYX]]x[[TYPE]]>, memref<[[NGCHIWI]]x[[TYPE]]>, memref<[[NGKHOWO]]x[[OTYPE]]>) -> ()
5555
// CHECK-NEXT: memref.dealloc %[[filter]]
5656
// CHECK-NEXT: memref.dealloc %[[input]]
5757
// CHECK-NEXT: memref.dealloc %[[output]]
5858
// CHECK-NEXT: return
5959

60-
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0_gpu(%{{.*}}: memref<[[GKCYX]]x[[TYPE]]>, %{{.*}}: memref<[[NGCHIWI]]x[[TYPE]]>, %{{.*}}: memref<[[NGKHOWO]]x[[OTYPE]]>)
60+
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_gpu(%{{.*}}: memref<[[GKCYX]]x[[TYPE]]>, %{{.*}}: memref<[[NGCHIWI]]x[[TYPE]]>, %{{.*}}: memref<[[NGKHOWO]]x[[OTYPE]]>)
6161
// CHECK-NEXT: gpu.alloc () : memref<[[GKCYX]]x[[TYPE]]>
6262
// CHECK-NEXT: gpu.memcpy %{{.*}}, %{{.*}} : memref<[[GKCYX]]x[[TYPE]]>, memref<[[GKCYX]]x[[TYPE]]>
6363
// CHECK-NEXT: gpu.alloc () : memref<[[NGCHIWI]]x[[TYPE]]>

mlir/test/rocmlir-driver/populate_host_splitk.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@
4545
// CHECK-NEXT: vector.extractelement
4646
// CHECK-NEXT: memref.store %{{.*}}, %[[output]][%[[io]]] : memref<[[NGKHOWO]]x[[OTYPE]]>
4747
// CHECK-NEXT: }
48-
// CHECK-NEXT: call @rock_conv_gkc01_ngc01_ngk01_0_gpu({{.*}}, {{.*}}, {{.*}}) : (memref<[[GKCYX]]x[[TYPE]]>, memref<[[NGCHIWI]]x[[TYPE]]>, memref<[[NGKHOWO]]x[[OTYPE]]>) -> ()
48+
// CHECK-NEXT: call @rock_conv_gkc01_ngc01_ngk01_gpu({{.*}}, {{.*}}, {{.*}}) : (memref<[[GKCYX]]x[[TYPE]]>, memref<[[NGCHIWI]]x[[TYPE]]>, memref<[[NGKHOWO]]x[[OTYPE]]>) -> ()
4949
// CHECK-NEXT: memref.dealloc %[[filter]]
5050
// CHECK-NEXT: memref.dealloc %[[input]]
5151
// CHECK-NEXT: memref.dealloc %[[output]]
5252
// CHECK-NEXT: return
5353

54-
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0_gpu(%{{.*}}: memref<[[GKCYX]]x[[TYPE]]>, %{{.*}}: memref<[[NGCHIWI]]x[[TYPE]]>, %{{.*}}: memref<[[NGKHOWO]]x[[OTYPE]]>)
54+
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_gpu(%{{.*}}: memref<[[GKCYX]]x[[TYPE]]>, %{{.*}}: memref<[[NGCHIWI]]x[[TYPE]]>, %{{.*}}: memref<[[NGKHOWO]]x[[OTYPE]]>)
5555
// CHECK-NEXT: gpu.alloc () : memref<[[GKCYX]]x[[TYPE]]>
5656
// CHECK-NEXT: gpu.memcpy %{{.*}}, %{{.*}} : memref<[[GKCYX]]x[[TYPE]]>, memref<[[GKCYX]]x[[TYPE]]>
5757
// CHECK-NEXT: gpu.alloc () : memref<[[NGCHIWI]]x[[TYPE]]>

mlir/test/rocmlir-driver/populate_pv_with_gpu.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0({{.*}}: memref<[[NFILTER:[0-9]+]]xf32>, {{.*}}: memref<[[NINPUT:[0-9]+]]xf32>, {{.*}}: memref<[[NOUTPUT:[0-9]+]]xf32>) attributes {kernel = 0 : i32, mhal.arch = "{{.*}}"} {
55
// CHECK: rock.conv({{.*}}) features = mfma|dot|atomic_add|atomic_add_f16 {[[PARMS:.*]]} : memref<[[FILTERDIMS:[x0-9]+]]xf32>, memref<[[INPUTDIMS:[x0-9]+]]xf32>, memref<[[OUTPUTDIMS:[x0-9]+]]xf32>
6-
// CHECK: call @rock_conv_gkc01_ngc01_ngk01_0_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
7-
// CHECK: call @rock_conv_gkc01_ngc01_ngk01_0_ver_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
6+
// CHECK: call @rock_conv_gkc01_ngc01_ngk01_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
7+
// CHECK: call @rock_conv_gkc01_ngc01_ngk01_ver_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
88
// CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0_ver({{.*}}) attributes {kernel = 0 : i32, mhal.arch = "{{.*}}"} {
99
// CHECK: rock.conv({{.*}}) features = dot|atomic_add|atomic_add_f16 {{{.*}}} : memref<[[FILTERDIMS]]xf32>, memref<[[INPUTDIMS]]xf32>, memref<[[OUTPUTDIMS]]xf32>
1010

@@ -31,8 +31,8 @@
3131

3232
// F16-CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0({{.*}}: memref<[[NFILTER:[0-9]+]]xf16>, {{.*}}: memref<[[NINPUT:[0-9]+]]xf16>, {{.*}}: memref<[[NOUTPUT:[0-9]+]]xf16>) attributes {kernel = 0 : i32, mhal.arch = "{{.*}}"} {
3333
// F16-CHECK: rock.conv({{.*}}) features = dot {[[PARMS:.*]]} : memref<[[FILTERDIMS:[x0-9]+]]xf16>, memref<[[INPUTDIMS:[x0-9]+]]xf16>, memref<[[OUTPUTDIMS:[x0-9]+]]xf16>
34-
// F16-CHECK: call @rock_conv_gkc01_ngc01_ngk01_0_gpu({{.*}}) : (memref<[[NFILTER]]xf16>, memref<[[NINPUT]]xf16>, memref<[[NOUTPUT]]xf16>) -> ()
35-
// F16-CHECK: call @rock_conv_gkc01_ngc01_ngk01_0_ver_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
34+
// F16-CHECK: call @rock_conv_gkc01_ngc01_ngk01_gpu({{.*}}) : (memref<[[NFILTER]]xf16>, memref<[[NINPUT]]xf16>, memref<[[NOUTPUT]]xf16>) -> ()
35+
// F16-CHECK: call @rock_conv_gkc01_ngc01_ngk01_ver_gpu({{.*}}) : (memref<[[NFILTER]]xf32>, memref<[[NINPUT]]xf32>, memref<[[NOUTPUT]]xf32>) -> ()
3636
// F16-CHECK: func.func @rock_conv_gkc01_ngc01_ngk01_0_ver({{.*}}) attributes {kernel = 0 : i32, mhal.arch = "{{.*}}"} {
3737
// F16-CHECK: rock.conv({{.*}}) features = dot {{{.*}} : memref<[[FILTERDIMS]]xf32>, memref<[[INPUTDIMS]]xf32>, memref<[[OUTPUTDIMS]]xf32>
3838

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,37 @@
1-
// RUN: rocmlir-gen --arch gfx900 --operation gemm -p -ph --kernel-repeats=5 | FileCheck %s
2-
// CHECK-LABEL: @rock_gemm_gpu
3-
// CHECK-DAG: %[[zero:.*]] = arith.constant 0 : index
4-
// CHECK-DAG: %[[one:.*]] = arith.constant 1 : index
5-
// CHECK-DAG: %[[five:.*]] = arith.constant 5 : index
6-
// CHECK: scf.for %{{.*}} = %[[zero]] to %[[five]] step %[[one]] {
7-
// CHECK-NEXT: func.call @rock_gemm
8-
// CHECK-NEXT: }
1+
// RUN: rocmlir-gen --arch gfx900 --operation gemm -p -ph --kernel-repeats=5 | FileCheck %s --check-prefix=GEMM
2+
// RUN: rocmlir-gen --arch gfx942 -pv --operation conv_bwd_weight -t f32 --fil_layout k01c --in_layout n01c --out_layout n01k --batchsize 64 --in_channels 1024 --in_h 14 --in_w 14 --out_channels 256 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --groupsize 1 --kernel-repeats 5 | FileCheck %s --check-prefix=CONV_WRW
3+
// RUN: rocmlir-gen --arch gfx942 -pv_with_gpu --operation conv_bwd_weight -t f32 --fil_layout k01c --in_layout n01c --out_layout n01k --batchsize 64 --in_channels 1024 --in_h 14 --in_w 14 --out_channels 256 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --groupsize 1 --kernel-repeats 5 | FileCheck %s --check-prefix=CONV_WRW_GPU
4+
5+
// GEMM-LABEL: @rock_gemm_gpu
6+
// GEMM-DAG: %[[zero:.*]] = arith.constant 0 : index
7+
// GEMM-DAG: %[[one:.*]] = arith.constant 1 : index
8+
// GEMM-DAG: %[[five:.*]] = arith.constant 5 : index
9+
// GEMM: scf.for %{{.*}} = %[[zero]] to %[[five]] step %[[one]] {
10+
// GEMM-NEXT: func.call @rock_gemm
11+
// GEMM-NEXT: }
12+
13+
// CONV_WRW-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_0
14+
// CONV_WRW: rock.init_kernel
15+
// CONV_WRW-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_1
16+
// CONV_WRW: rock.conv_bwd_weight
17+
// CONV_WRW-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_gpu
18+
// CONV_WRW-DAG: %[[one:.*]] = arith.constant 1 : index
19+
// CONV_WRW-DAG: %[[five:.*]] = arith.constant 5 : index
20+
// CONV_WRW-DAG: %[[zero:.*]] = arith.constant 0 : index
21+
// CONV_WRW: scf.for %{{.*}} = %[[zero]] to %[[five]] step %[[one]] {
22+
// CONV_WRW-NEXT: func.call @rock_conv_bwd_weight_gk01c_n01gc_n01gk_0
23+
// CONV_WRW-NEXT: func.call @rock_conv_bwd_weight_gk01c_n01gc_n01gk_1
24+
// CONV_WRW-NEXT: }
25+
26+
// CONV_WRW_GPU-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_0
27+
// CONV_WRW_GPU: rock.init_kernel
28+
// CONV_WRW_GPU-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_1
29+
// CONV_WRW_GPU: rock.conv_bwd_weight
30+
// CONV_WRW_GPU-LABEL: func.func @rock_conv_bwd_weight_gk01c_n01gc_n01gk_gpu
31+
// CONV_WRW_GPU-DAG: %[[zero:.*]] = arith.constant 0 : index
32+
// CONV_WRW_GPU-DAG: %[[one:.*]] = arith.constant 1 : index
33+
// CONV_WRW_GPU-DAG: %[[five:.*]] = arith.constant 5 : index
34+
// CONV_WRW_GPU: scf.for %{{.*}} = %[[zero]] to %[[five]] step %[[one]] {
35+
// CONV_WRW_GPU-NEXT: func.call @rock_conv_bwd_weight_gk01c_n01gc_n01gk_0
36+
// CONV_WRW_GPU-NEXT: func.call @rock_conv_bwd_weight_gk01c_n01gc_n01gk_1
37+
// CONV_WRW_GPU-NEXT: }

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include "mlir/Support/LogicalResult.h"
5959

6060
#include "llvm/ADT/STLExtras.h"
61+
#include "llvm/ADT/SmallSet.h"
6162
#include "llvm/ADT/StringRef.h"
6263
#include "llvm/ADT/StringSwitch.h"
6364
#include "llvm/Support/CommandLine.h"
@@ -1277,18 +1278,18 @@ static Value makeNDMemRef(OpBuilder &b, Value var, uint32_t ndim) {
12771278

12781279
return var;
12791280
}
1280-
1281-
static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
1281+
static func::FuncOp createGPUWrapper(ModuleOp module,
1282+
const std::string &funcName,
1283+
const SmallVector<KernelIF, 8> &kernels) {
12821284
MLIRContext *context = module.getContext();
12831285
OpBuilder b(context);
1284-
auto loc = kernel.func->getLoc();
1286+
auto loc = kernels[0].func->getLoc();
12851287

12861288
// Create gpu wrapper function
1287-
auto kfunc = kernel.func;
1288-
std::string funcName = kfunc.getName().str() + "_gpu";
1289-
auto gpuWrapperFuncType = b.getFunctionType(kernel.params, {});
1289+
std::string funcNameGpu = funcName + "_gpu";
1290+
auto gpuWrapperFuncType = b.getFunctionType(kernels[0].params, {});
12901291
auto gpuWrapperFunc =
1291-
func::FuncOp::create(loc, StringRef(funcName), gpuWrapperFuncType);
1292+
func::FuncOp::create(loc, StringRef(funcNameGpu), gpuWrapperFuncType);
12921293
module.push_back(gpuWrapperFunc);
12931294

12941295
// Emit gpu convolution logic.
@@ -1303,7 +1304,7 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13031304

13041305
SmallVector<Value, 4> cpuMem;
13051306
SmallVector<Value, 4> gpuMem;
1306-
for (auto pair : llvm::enumerate(kernel.params)) {
1307+
for (auto pair : llvm::enumerate(kernels[0].params)) {
13071308
Value arg = block->getArgument(pair.index());
13081309
cpuMem.push_back(arg);
13091310

@@ -1321,11 +1322,12 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13211322
// Emit a call to each kernel function, repeating the whole sequence if needed.
13221323
// We assume that the repeated atomic add usages in a wrw kernel will not
13231324
// substantially impact performance as the result becomes large
1324-
auto emitWrappedCall = [&kernel, &gpuMem](OpBuilder &b, Location loc,
1325-
Value ignoredIv,
1326-
ValueRange noArgs) {
1327-
auto wrappedCall = b.create<func::CallOp>(loc, kernel.func, gpuMem);
1328-
wrappedCall->setAttr("wrapped_call", b.getUnitAttr());
1325+
auto emitWrappedCall = [&kernels, &gpuMem](OpBuilder &b, Location loc,
1326+
Value ignoredIv,
1327+
ValueRange noArgs) {
1328+
for (const auto &kernel : kernels) {
1329+
b.create<func::CallOp>(loc, kernel.func, gpuMem);
1330+
}
13291331
if (ignoredIv) { // we're creating an actual loop
13301332
b.create<scf::YieldOp>(loc);
13311333
}
@@ -1341,14 +1343,12 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13411343
emitWrappedCall(b, loc, nullptr, {});
13421344
}
13431345

1344-
for (auto pair : llvm::enumerate(kernel.params)) {
1346+
for (auto pair : llvm::enumerate(kernels[0].params)) {
13451347
uint32_t i = pair.index();
13461348
b.create<gpu::MemcpyOp>(loc, TypeRange{}, ValueRange{cpuMem[i], gpuMem[i]});
13471349
b.create<gpu::DeallocOp>(loc, TypeRange{}, ValueRange{gpuMem[i]});
13481350
}
1349-
13501351
b.create<func::ReturnOp>(loc, ValueRange{});
1351-
13521352
return gpuWrapperFunc;
13531353
}
13541354

@@ -3424,35 +3424,34 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
34243424
}
34253425
// generate all sub-kernels, and get corresponding gemmId
34263426
std::string kernelBaseName = genConfig.kernelBaseName;
3427+
SmallVector<KernelIF, 8> kernelIFFuncs;
34273428
for (int i = kernelStart; i < kernelCount; ++i) {
34283429
convGenerator.setKernelName(kernelBaseName + "_" + std::to_string(i));
34293430
if (failed(convGenerator.genConvModule(module, i, true,
34303431
/*ignoreTuning=*/true))) {
34313432
llvm::errs() << "Module population failed.\n";
34323433
exit(1);
34333434
}
3434-
KernelIF kernel(convGenerator.getKernelFunc());
3435-
auto kernelWrapperFunc = createGPUWrapper(module, kernel);
3436-
3437-
// Decide whether to trim the last workspace argument to the verifier
3438-
// GPU kernel.
3439-
rock::ConvGenerator originalConvGenerator(genConfig);
3440-
bool originalHasWorkspace = false, verifierHasWorkspace = false;
3441-
if (failed(
3442-
originalConvGenerator.hasWorkspace(b, originalHasWorkspace))) {
3443-
llvm::errs() << "Getting workspace failed.\n";
3444-
exit(1);
3445-
}
3446-
if (failed(convGenerator.hasWorkspace(b, verifierHasWorkspace))) {
3447-
llvm::errs() << "Getting workspace failed.\n";
3448-
exit(1);
3449-
}
3450-
if (originalHasWorkspace && !verifierHasWorkspace) {
3451-
valVars.resize(valVars.size() - 1);
3452-
}
3453-
3454-
b.create<func::CallOp>(loc, kernelWrapperFunc, valVars);
3435+
kernelIFFuncs.push_back(convGenerator.getKernelFunc());
3436+
}
3437+
// Decide whether to trim the last workspace argument to the verifier
3438+
// GPU kernel.
3439+
rock::ConvGenerator originalConvGenerator(genConfig);
3440+
bool originalHasWorkspace = false, verifierHasWorkspace = false;
3441+
if (failed(originalConvGenerator.hasWorkspace(b, originalHasWorkspace))) {
3442+
llvm::errs() << "Getting workspace failed.\n";
3443+
exit(1);
3444+
}
3445+
if (failed(convGenerator.hasWorkspace(b, verifierHasWorkspace))) {
3446+
llvm::errs() << "Getting workspace failed.\n";
3447+
exit(1);
34553448
}
3449+
if (originalHasWorkspace && !verifierHasWorkspace) {
3450+
valVars.resize(valVars.size() - 1);
3451+
}
3452+
auto kernelWrapperFunc =
3453+
createGPUWrapper(module, kernelBaseName + "_ver", kernelIFFuncs);
3454+
b.create<func::CallOp>(loc, kernelWrapperFunc, valVars);
34563455
convGenerator.setKernelName(kernelBaseName);
34573456
} else { // gemm GPU validation
34583457
GenParams newParams = genParams;
@@ -3473,7 +3472,8 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
34733472

34743473
KernelIF kernel(
34753474
createGpuGemmKernel(module, newParams, /*isVerifier=*/true));
3476-
auto kernelWrapperFunc = createGPUWrapper(module, kernel);
3475+
auto kernelWrapperFunc =
3476+
createGPUWrapper(module, kernel.func.getName().str(), {kernel});
34773477
b.create<func::CallOp>(loc, kernelWrapperFunc, valVars);
34783478
}
34793479
} else if (validationType != "clone") { // -pv_with_cpp or -pv_with_mlir (-pv)
@@ -3759,31 +3759,33 @@ static LogicalResult populateHostHarnessLogic(
37593759

37603760
b.create<func::ReturnOp>(loc, ValueRange{});
37613761

3762-
// Wrap the kernels and gather them to substitute in calls.
3763-
llvm::SmallDenseMap<func::FuncOp, func::FuncOp> wrappedFuncs;
3762+
// Collect the functions tagged with the "kernel" attribute so that calls to them can be found and rewritten below.
3763+
llvm::SmallSetVector<func::FuncOp, 4> kernelsSet;
3764+
std::string kernelBaseName =
3765+
(genParams.convConfig.has_value())
3766+
? genParams.convConfig.value()->kernelBaseName
3767+
: root0.func.getName().str();
37643768
for (auto &kernel : kernels) {
37653769
if (kernel.func->hasAttr("kernel")) {
3766-
wrappedFuncs[kernel.func] = createGPUWrapper(module, kernel);
3767-
} else {
3768-
wrappedFuncs[kernel.func] = kernel.func;
3770+
kernelsSet.insert(kernel.func);
37693771
}
37703772
}
3771-
3773+
func::FuncOp gpuWrapperFunc;
3774+
if (!kernelsSet.empty())
3775+
gpuWrapperFunc = createGPUWrapper(module, kernelBaseName, kernels);
37723776
// Point the call to root0.func at the single GPU wrapper and erase calls to the other kernel functions, which the wrapper already calls.
3773-
module.walk([&](CallOpInterface callOp) -> WalkResult {
3774-
// Don't substitute the call inside the wrapper.
3775-
if (callOp->hasAttr("wrapped_call")) {
3776-
callOp->removeAttr("wrapped_call");
3777-
return WalkResult::advance();
3778-
}
3779-
3777+
func.walk([&](CallOpInterface callOp) -> WalkResult {
37803778
// If the callee matches a wrapped function, update the call.
37813779
Operation *callable = callOp.resolveCallable();
37823780
if (callable) {
37833781
func::FuncOp fop = dyn_cast<func::FuncOp>(*callable);
3784-
if (wrappedFuncs.find(fop) != wrappedFuncs.end()) {
3782+
if (kernelsSet.contains(fop)) {
3783+
if (fop != root0.func) {
3784+
callOp->erase();
3785+
return WalkResult::advance();
3786+
}
37853787
callOp->setAttr("callee", FlatSymbolRefAttr::get(
3786-
context, wrappedFuncs[fop].getSymName()));
3788+
context, gpuWrapperFunc.getSymName()));
37873789
}
37883790
}
37893791
return WalkResult::advance();

0 commit comments

Comments
 (0)