@@ -437,8 +437,8 @@ Error GenericKernelTy::init(GenericDeviceTy &GenericDevice,
437437Expected<KernelLaunchEnvironmentTy *>
438438GenericKernelTy::getKernelLaunchEnvironment (
439439 GenericDeviceTy &GenericDevice, const KernelArgsTy &KernelArgs,
440- uint32_t BlockMemSize, DynCGroupMemFallbackType DynBlockMemFb ,
441- void *DynBlockMemFbPtr, AsyncInfoWrapperTy &AsyncInfoWrapper) const {
440+ const DynBlockMemConfTy &DynBlockMemConf ,
441+ AsyncInfoWrapperTy &AsyncInfoWrapper) const {
442442 // Ctor/Dtor have no arguments, replaying uses the original kernel launch
443443 // environment. Older versions of the compiler do not generate a kernel
444444 // launch environment.
@@ -480,9 +480,9 @@ GenericKernelTy::getKernelLaunchEnvironment(
480480 LocalKLE.ReductionBuffer = nullptr ;
481481 }
482482
483- LocalKLE.DynCGroupMemSize = BlockMemSize ;
484- LocalKLE.DynCGroupMemFbPtr = DynBlockMemFbPtr ;
485- LocalKLE.DynCGroupMemFb = DynBlockMemFb ;
483+ LocalKLE.DynCGroupMemSize = DynBlockMemConf. Size ;
484+ LocalKLE.DynCGroupMemFbPtr = DynBlockMemConf. FallbackPtr ;
485+ LocalKLE.DynCGroupMemFb = DynBlockMemConf. Fallback ;
486486
487487 INFO (OMP_INFOTYPE_DATA_TRANSFER, GenericDevice.getDeviceId (),
488488 " Copying data from host to device, HstPtr=" DPxMOD " , TgtPtr=" DPxMOD
@@ -518,47 +518,51 @@ Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
518518 return Plugin::success ();
519519}
520520
521- Expected<DynBlockMemInfoTy> prepareBlockMemory (GenericDeviceTy &GenericDevice, KernelArgsTy &KernelArgs) {
522- uint32_t MaxSize = GenericDevice.getMaxBlockSharedMemSize ();
523- uint32_t DynSize = KernelArgs.DynCGroupMem ;
524- uint32_t TotalSize = StaticSize + DynSize;
525- uint32_t DynNativeSize = DynSize;
521+ Expected<DynBlockMemConfTy>
522+ GenericKernelTy::prepareBlockMemory (GenericDeviceTy &GenericDevice,
523+ KernelArgsTy &KernelArgs,
524+ uint32_t NumBlocks) const {
525+ uint32_t MaxBlockMemSize = GenericDevice.getMaxBlockSharedMemSize ();
526+ uint32_t DynBlockMemSize = KernelArgs.DynCGroupMem ;
527+ uint32_t TotalBlockMemSize = StaticBlockMemSize + DynBlockMemSize;
528+ uint32_t DynNativeBlockMemSize = DynBlockMemSize;
526529 void *DynFallbackPtr = nullptr ;
527530
528531 // No enough block memory to cover the static one. Cannot run the kernel.
529- if (StaticSize > MaxSize )
532+ if (StaticBlockMemSize > MaxBlockMemSize )
530533 return Plugin::error (ErrorCode::INVALID_ARGUMENT,
531534 " Static block memory size exceeds maximum" );
532535 // No enough block memory to cover dynamic one, and the fallback is aborting.
533536 else if (static_cast <DynCGroupMemFallbackType>(
534537 KernelArgs.Flags .DynCGroupMemFallback ) ==
535538 DynCGroupMemFallbackType::Abort &&
536- TotalSize > MaxSize )
539+ TotalBlockMemSize > MaxBlockMemSize )
537540 return Plugin::error (
538541 ErrorCode::INVALID_ARGUMENT,
539542 " Static and dynamic block memory size exceeds maximum" );
540543
541544 DynCGroupMemFallbackType DynFallback = DynCGroupMemFallbackType::None;
542- if (DynSize && (!GenericDevice.hasNativeBlockSharedMem () ||
543- TotalSize > MaxSize )) {
545+ if (DynBlockMemSize && (!GenericDevice.hasNativeBlockSharedMem () ||
546+ TotalBlockMemSize > MaxBlockMemSize )) {
544547 // Launch without native dynamic block memory.
545- DynNativeSize = 0 ;
548+ DynNativeBlockMemSize = 0 ;
546549 DynFallback = static_cast <DynCGroupMemFallbackType>(
547550 KernelArgs.Flags .DynCGroupMemFallback );
548551 if (DynFallback == DynCGroupMemFallbackType::DefaultMem) {
549552 // Get global memory as fallback.
550553 auto AllocOrErr = GenericDevice.dataAlloc (
551- NumBlocks[ 0 ] * DynSize ,
554+ NumBlocks * DynBlockMemSize ,
552555 /* HostPtr=*/ nullptr , TargetAllocTy::TARGET_ALLOC_DEVICE);
553556 if (!AllocOrErr)
554557 return AllocOrErr.takeError ();
555558 DynFallbackPtr = *AllocOrErr;
556559 } else {
557560 // Do not provide any memory as fallback.
558- DynSize = 0 ;
561+ DynBlockMemSize = 0 ;
559562 }
560563 }
561- return { DynSize, DynNativeSize, DynFallback, DynFallbackPtr };
564+ return DynBlockMemConfTy{DynBlockMemSize, DynNativeBlockMemSize, DynFallback,
565+ DynFallbackPtr};
562566}
563567
564568Error GenericKernelTy::launch (GenericDeviceTy &GenericDevice, void **ArgPtrs,
@@ -578,17 +582,18 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
578582 NumThreads[0 ], KernelArgs.ThreadLimit [0 ] > 0 );
579583 }
580584
581- auto DynBlockMemInfoOrErr = prepareBlockMemory (GenericDevice, KernelArgs);
582- if (!DynBlockMemInfoOrErr)
583- return DynBlockMemInfoOrErr.takeError ();
585+ auto DynBlockMemConfOrErr =
586+ prepareBlockMemory (GenericDevice, KernelArgs, NumBlocks[0 ]);
587+ if (!DynBlockMemConfOrErr)
588+ return DynBlockMemConfOrErr.takeError ();
584589
585- DynBlockMemInfoTy &DynBlockMemInfo = *DynBlockMemInfoOrErr;
586- if (DynBlockMemInfo.FallbackPtr )
587- AsyncInfoWrapper.freeAllocationAfterSynchronization (DynBlockMemInfo.FallbackPtr );
590+ DynBlockMemConfTy &DynBlockMemConf = *DynBlockMemConfOrErr;
591+ if (DynBlockMemConf.FallbackPtr )
592+ AsyncInfoWrapper.freeAllocationAfterSynchronization (
593+ DynBlockMemConf.FallbackPtr );
588594
589595 auto KernelLaunchEnvOrErr = getKernelLaunchEnvironment (
590- GenericDevice, KernelArgs, DynBlockMemInfo.Size , DynBlockMemInfo.Fallback ,
591- DynBlockMemInfo.FallbackPtr , AsyncInfoWrapper);
596+ GenericDevice, KernelArgs, DynBlockMemConf, AsyncInfoWrapper);
592597 if (!KernelLaunchEnvOrErr)
593598 return KernelLaunchEnvOrErr.takeError ();
594599
@@ -619,8 +624,9 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
619624 printLaunchInfo (GenericDevice, KernelArgs, NumThreads, NumBlocks))
620625 return Err;
621626
622- return launchImpl (GenericDevice, NumThreads, NumBlocks, DynBlockMemInfo.NativeSize ,
623- KernelArgs, LaunchParams, AsyncInfoWrapper);
627+ return launchImpl (GenericDevice, NumThreads, NumBlocks,
628+ DynBlockMemConf.NativeSize , KernelArgs, LaunchParams,
629+ AsyncInfoWrapper);
624630}
625631
626632KernelLaunchParamsTy GenericKernelTy::prepareArgs (
@@ -2044,7 +2050,7 @@ InfoTreeNode GenericPluginTy::obtain_device_info(int32_t DeviceId) {
20442050 toString (std::move (Err)).data ());
20452051 return InfoTreeNode{};
20462052 }
2047- return *InfoOrErr;
2053+ return std::move ( *InfoOrErr) ;
20482054}
20492055
20502056void GenericPluginTy::print_device_info (int32_t DeviceId) {
0 commit comments