@@ -437,7 +437,8 @@ Error GenericKernelTy::init(GenericDeviceTy &GenericDevice,
437437Expected<KernelLaunchEnvironmentTy *>
438438GenericKernelTy::getKernelLaunchEnvironment (
439439 GenericDeviceTy &GenericDevice, const KernelArgsTy &KernelArgs,
440- void *FallbackBlockMem, AsyncInfoWrapperTy &AsyncInfoWrapper) const {
440+ uint32_t BlockMemSize, DynCGroupMemFallbackType DynBlockMemFb,
441+ void *DynBlockMemFbPtr, AsyncInfoWrapperTy &AsyncInfoWrapper) const {
441442 // Ctor/Dtor have no arguments, replaying uses the original kernel launch
442443 // environment. Older versions of the compiler do not generate a kernel
443444 // launch environment.
@@ -479,8 +480,9 @@ GenericKernelTy::getKernelLaunchEnvironment(
479480 LocalKLE.ReductionBuffer = nullptr ;
480481 }
481482
482- LocalKLE.DynCGroupMemSize = KernelArgs.DynCGroupMem ;
483- LocalKLE.DynCGroupMemFallback = FallbackBlockMem;
483+ LocalKLE.DynCGroupMemSize = BlockMemSize;
484+ LocalKLE.DynCGroupMemFbPtr = DynBlockMemFbPtr;
485+ LocalKLE.DynCGroupMemFb = DynBlockMemFb;
484486
485487 INFO (OMP_INFOTYPE_DATA_TRANSFER, GenericDevice.getDeviceId (),
486488 " Copying data from host to device, HstPtr=" DPxMOD " , TgtPtr=" DPxMOD
@@ -539,28 +541,43 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
539541 if (StaticBlockMemSize > MaxBlockMemSize)
540542 return Plugin::error (ErrorCode::INVALID_ARGUMENT,
541543 " Static block memory size exceeds maximum" );
542- else if (!KernelArgs.Flags .AllowDynCGroupMemFallback &&
544+ else if (static_cast <DynCGroupMemFallbackType>(
545+ KernelArgs.Flags .DynCGroupMemFallback ) ==
546+ DynCGroupMemFallbackType::Abort &&
543547 TotalBlockMemSize > MaxBlockMemSize)
544548 return Plugin::error (
545549 ErrorCode::INVALID_ARGUMENT,
546550 " Static and dynamic block memory size exceeds maximum" );
547551
548- void *FallbackBlockMem = nullptr ;
552+ void *DynBlockMemFbPtr = nullptr ;
553+ uint32_t DynBlockMemLaunchSize = DynBlockMemSize;
554+
555+ DynCGroupMemFallbackType DynBlockMemFb = DynCGroupMemFallbackType::None;
549556 if (DynBlockMemSize && (!GenericDevice.hasNativeBlockSharedMem () ||
550557 TotalBlockMemSize > MaxBlockMemSize)) {
551- auto AllocOrErr = GenericDevice.dataAlloc (
552- NumBlocks[0 ] * DynBlockMemSize,
553- /* HostPtr=*/ nullptr , TargetAllocTy::TARGET_ALLOC_DEVICE);
554- if (!AllocOrErr)
555- return AllocOrErr.takeError ();
558+ // Launch without native dynamic block memory.
559+ DynBlockMemLaunchSize = 0 ;
560+ DynBlockMemFb = static_cast <DynCGroupMemFallbackType>(
561+ KernelArgs.Flags .DynCGroupMemFallback );
562+ if (DynBlockMemFb == DynCGroupMemFallbackType::DefaultMem) {
563+ // Get global memory as fallback.
564+ auto AllocOrErr = GenericDevice.dataAlloc (
565+ NumBlocks[0 ] * DynBlockMemSize,
566+ /* HostPtr=*/ nullptr , TargetAllocTy::TARGET_ALLOC_DEVICE);
567+ if (!AllocOrErr)
568+ return AllocOrErr.takeError ();
556569
557- FallbackBlockMem = *AllocOrErr;
558- AsyncInfoWrapper.freeAllocationAfterSynchronization (FallbackBlockMem);
559- DynBlockMemSize = 0 ;
570+ DynBlockMemFbPtr = *AllocOrErr;
571+ AsyncInfoWrapper.freeAllocationAfterSynchronization (DynBlockMemFbPtr);
572+ } else {
573+ // Do not provide any memory as fallback.
574+ DynBlockMemSize = 0 ;
575+ }
560576 }
561577
562578 auto KernelLaunchEnvOrErr = getKernelLaunchEnvironment (
563- GenericDevice, KernelArgs, FallbackBlockMem, AsyncInfoWrapper);
579+ GenericDevice, KernelArgs, DynBlockMemSize, DynBlockMemFb,
580+ DynBlockMemFbPtr, AsyncInfoWrapper);
564581 if (!KernelLaunchEnvOrErr)
565582 return KernelLaunchEnvOrErr.takeError ();
566583
@@ -591,7 +608,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
591608 printLaunchInfo (GenericDevice, KernelArgs, NumThreads, NumBlocks))
592609 return Err;
593610
594- return launchImpl (GenericDevice, NumThreads, NumBlocks, DynBlockMemSize ,
611+ return launchImpl (GenericDevice, NumThreads, NumBlocks, DynBlockMemLaunchSize ,
595612 KernelArgs, LaunchParams, AsyncInfoWrapper);
596613}
597614
0 commit comments