5858#include " mlir/Support/LogicalResult.h"
5959
6060#include " llvm/ADT/STLExtras.h"
61+ #include " llvm/ADT/SmallSet.h"
6162#include " llvm/ADT/StringRef.h"
6263#include " llvm/ADT/StringSwitch.h"
6364#include " llvm/Support/CommandLine.h"
@@ -1277,18 +1278,18 @@ static Value makeNDMemRef(OpBuilder &b, Value var, uint32_t ndim) {
12771278
12781279 return var;
12791280}
1280-
1281- static func::FuncOp createGPUWrapper (ModuleOp module , const KernelIF &kernel) {
1281+ static func::FuncOp createGPUWrapper (ModuleOp module ,
1282+ const std::string &funcName,
1283+ const SmallVector<KernelIF, 8 > &kernels) {
12821284 MLIRContext *context = module .getContext ();
12831285 OpBuilder b (context);
1284- auto loc = kernel .func ->getLoc ();
1286+ auto loc = kernels[ 0 ] .func ->getLoc ();
12851287
12861288 // Create gpu wrapper function
1287- auto kfunc = kernel.func ;
1288- std::string funcName = kfunc.getName ().str () + " _gpu" ;
1289- auto gpuWrapperFuncType = b.getFunctionType (kernel.params , {});
1289+ std::string funcNameGpu = funcName + " _gpu" ;
1290+ auto gpuWrapperFuncType = b.getFunctionType (kernels[0 ].params , {});
12901291 auto gpuWrapperFunc =
1291- func::FuncOp::create (loc, StringRef (funcName ), gpuWrapperFuncType);
1292+ func::FuncOp::create (loc, StringRef (funcNameGpu ), gpuWrapperFuncType);
12921293 module .push_back (gpuWrapperFunc);
12931294
12941295 // Emit gpu convolution logic.
@@ -1303,7 +1304,7 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13031304
13041305 SmallVector<Value, 4 > cpuMem;
13051306 SmallVector<Value, 4 > gpuMem;
1306- for (auto pair : llvm::enumerate (kernel .params )) {
1307+ for (auto pair : llvm::enumerate (kernels[ 0 ] .params )) {
13071308 Value arg = block->getArgument (pair.index ());
13081309 cpuMem.push_back (arg);
13091310
@@ -1321,11 +1322,12 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13211322 // Emit kernel function call, repeating it if needed.
13221323 // We assume that the repeated atomic add usages in a wrw kernel will not
13231324 // substantially impact performance as the result becomes large
1324- auto emitWrappedCall = [&kernel, &gpuMem](OpBuilder &b, Location loc,
1325- Value ignoredIv,
1326- ValueRange noArgs) {
1327- auto wrappedCall = b.create <func::CallOp>(loc, kernel.func , gpuMem);
1328- wrappedCall->setAttr (" wrapped_call" , b.getUnitAttr ());
1325+ auto emitWrappedCall = [&kernels, &gpuMem](OpBuilder &b, Location loc,
1326+ Value ignoredIv,
1327+ ValueRange noArgs) {
1328+ for (const auto &kernel : kernels) {
1329+ b.create <func::CallOp>(loc, kernel.func , gpuMem);
1330+ }
13291331 if (ignoredIv) { // we're creating an actual loop
13301332 b.create <scf::YieldOp>(loc);
13311333 }
@@ -1341,14 +1343,12 @@ static func::FuncOp createGPUWrapper(ModuleOp module, const KernelIF &kernel) {
13411343 emitWrappedCall (b, loc, nullptr , {});
13421344 }
13431345
1344- for (auto pair : llvm::enumerate (kernel .params )) {
1346+ for (auto pair : llvm::enumerate (kernels[ 0 ] .params )) {
13451347 uint32_t i = pair.index ();
13461348 b.create <gpu::MemcpyOp>(loc, TypeRange{}, ValueRange{cpuMem[i], gpuMem[i]});
13471349 b.create <gpu::DeallocOp>(loc, TypeRange{}, ValueRange{gpuMem[i]});
13481350 }
1349-
13501351 b.create <func::ReturnOp>(loc, ValueRange{});
1351-
13521352 return gpuWrapperFunc;
13531353}
13541354
@@ -3424,35 +3424,34 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
34243424 }
34253425 // generate all sub-kernels, and get corresponding gemmId
34263426 std::string kernelBaseName = genConfig.kernelBaseName ;
3427+ SmallVector<KernelIF, 8 > kernelIFFuncs;
34273428 for (int i = kernelStart; i < kernelCount; ++i) {
34283429 convGenerator.setKernelName (kernelBaseName + " _" + std::to_string (i));
34293430 if (failed (convGenerator.genConvModule (module , i, true ,
34303431 /* ignoreTuning=*/ true ))) {
34313432 llvm::errs () << " Module population failed.\n " ;
34323433 exit (1 );
34333434 }
3434- KernelIF kernel (convGenerator.getKernelFunc ());
3435- auto kernelWrapperFunc = createGPUWrapper (module , kernel);
3436-
3437- // Decide whether to trim the last workspace argument to the verifier
3438- // GPU kernel.
3439- rock::ConvGenerator originalConvGenerator (genConfig);
3440- bool originalHasWorkspace = false , verifierHasWorkspace = false ;
3441- if (failed (
3442- originalConvGenerator.hasWorkspace (b, originalHasWorkspace))) {
3443- llvm::errs () << " Getting workspace failed.\n " ;
3444- exit (1 );
3445- }
3446- if (failed (convGenerator.hasWorkspace (b, verifierHasWorkspace))) {
3447- llvm::errs () << " Getting workspace failed.\n " ;
3448- exit (1 );
3449- }
3450- if (originalHasWorkspace && !verifierHasWorkspace) {
3451- valVars.resize (valVars.size () - 1 );
3452- }
3453-
3454- b.create <func::CallOp>(loc, kernelWrapperFunc, valVars);
3435+ kernelIFFuncs.push_back (convGenerator.getKernelFunc ());
3436+ }
3437+ // Decide whether to trim the last workspace argument to the verifier
3438+ // GPU kernel.
3439+ rock::ConvGenerator originalConvGenerator (genConfig);
3440+ bool originalHasWorkspace = false , verifierHasWorkspace = false ;
3441+ if (failed (originalConvGenerator.hasWorkspace (b, originalHasWorkspace))) {
3442+ llvm::errs () << " Getting workspace failed.\n " ;
3443+ exit (1 );
3444+ }
3445+ if (failed (convGenerator.hasWorkspace (b, verifierHasWorkspace))) {
3446+ llvm::errs () << " Getting workspace failed.\n " ;
3447+ exit (1 );
34553448 }
3449+ if (originalHasWorkspace && !verifierHasWorkspace) {
3450+ valVars.resize (valVars.size () - 1 );
3451+ }
3452+ auto kernelWrapperFunc =
3453+ createGPUWrapper (module , kernelBaseName + " _ver" , kernelIFFuncs);
3454+ b.create <func::CallOp>(loc, kernelWrapperFunc, valVars);
34563455 convGenerator.setKernelName (kernelBaseName);
34573456 } else { // gemm GPU validation
34583457 GenParams newParams = genParams;
@@ -3473,7 +3472,8 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
34733472
34743473 KernelIF kernel (
34753474 createGpuGemmKernel (module , newParams, /* isVerifier=*/ true ));
3476- auto kernelWrapperFunc = createGPUWrapper (module , kernel);
3475+ auto kernelWrapperFunc =
3476+ createGPUWrapper (module , kernel.func .getName ().str (), {kernel});
34773477 b.create <func::CallOp>(loc, kernelWrapperFunc, valVars);
34783478 }
34793479 } else if (validationType != " clone" ) { // -pv_with_cpp or -pv_with_mlir (-pv)
@@ -3759,31 +3759,33 @@ static LogicalResult populateHostHarnessLogic(
37593759
37603760 b.create <func::ReturnOp>(loc, ValueRange{});
37613761
3762- // Wrap the kernels and gather them to substitute in calls.
3763- llvm::SmallDenseMap<func::FuncOp, func::FuncOp> wrappedFuncs;
3762+ // Set of kernels
3763+ llvm::SmallSetVector<func::FuncOp, 4 > kernelsSet;
3764+ std::string kernelBaseName =
3765+ (genParams.convConfig .has_value ())
3766+ ? genParams.convConfig .value ()->kernelBaseName
3767+ : root0.func .getName ().str ();
37643768 for (auto &kernel : kernels) {
37653769 if (kernel.func ->hasAttr (" kernel" )) {
3766- wrappedFuncs[kernel.func ] = createGPUWrapper (module , kernel);
3767- } else {
3768- wrappedFuncs[kernel.func ] = kernel.func ;
3770+ kernelsSet.insert (kernel.func );
37693771 }
37703772 }
3771-
3773+ func::FuncOp gpuWrapperFunc;
3774+ if (!kernelsSet.empty ())
3775+ gpuWrapperFunc = createGPUWrapper (module , kernelBaseName, kernels);
37723776 // Redirect calls to kernel functions to point at wrapped functions.
3773- module .walk ([&](CallOpInterface callOp) -> WalkResult {
3774- // Don't substitute the call inside the wrapper.
3775- if (callOp->hasAttr (" wrapped_call" )) {
3776- callOp->removeAttr (" wrapped_call" );
3777- return WalkResult::advance ();
3778- }
3779-
3777+ func.walk ([&](CallOpInterface callOp) -> WalkResult {
37803778 // If the callee matches a wrapped function, update the call.
37813779 Operation *callable = callOp.resolveCallable ();
37823780 if (callable) {
37833781 func::FuncOp fop = dyn_cast<func::FuncOp>(*callable);
3784- if (wrappedFuncs.find (fop) != wrappedFuncs.end ()) {
3782+ if (kernelsSet.contains (fop)) {
3783+ if (fop != root0.func ) {
3784+ callOp->erase ();
3785+ return WalkResult::advance ();
3786+ }
37853787 callOp->setAttr (" callee" , FlatSymbolRefAttr::get (
3786- context, wrappedFuncs[fop] .getSymName ()));
3788+ context, gpuWrapperFunc .getSymName ()));
37873789 }
37883790 }
37893791 return WalkResult::advance ();
0 commit comments