@@ -518,66 +518,77 @@ Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
518518 return Plugin::success ();
519519}
520520
521- Error GenericKernelTy::launch (GenericDeviceTy &GenericDevice, void **ArgPtrs,
522- ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
523- AsyncInfoWrapperTy &AsyncInfoWrapper) const {
524- llvm::SmallVector<void *, 16 > Args;
525- llvm::SmallVector<void *, 16 > Ptrs;
526-
527- uint32_t NumThreads[3 ] = {KernelArgs.ThreadLimit [0 ],
528- KernelArgs.ThreadLimit [1 ],
529- KernelArgs.ThreadLimit [2 ]};
530- uint32_t NumBlocks[3 ] = {KernelArgs.NumTeams [0 ], KernelArgs.NumTeams [1 ],
531- KernelArgs.NumTeams [2 ]};
532- if (!isBareMode ()) {
533- NumThreads[0 ] = getNumThreads (GenericDevice, NumThreads);
534- NumBlocks[0 ] = getNumBlocks (GenericDevice, NumBlocks, KernelArgs.Tripcount ,
535- NumThreads[0 ], KernelArgs.ThreadLimit [0 ] > 0 );
536- }
537-
538- uint32_t MaxBlockMemSize = GenericDevice.getMaxBlockSharedMemSize ();
539- uint32_t DynBlockMemSize = KernelArgs.DynCGroupMem ;
540- uint32_t TotalBlockMemSize = StaticBlockMemSize + DynBlockMemSize;
541- if (StaticBlockMemSize > MaxBlockMemSize)
521+ Expected<DynBlockMemInfoTy> prepareBlockMemory (GenericDeviceTy &GenericDevice, KernelArgsTy &KernelArgs) {
522+ uint32_t MaxSize = GenericDevice.getMaxBlockSharedMemSize ();
523+ uint32_t DynSize = KernelArgs.DynCGroupMem ;
524+ uint32_t TotalSize = StaticSize + DynSize;
525+ uint32_t DynNativeSize = DynSize;
526+ void *DynFallbackPtr = nullptr ;
527+
528+ // No enough block memory to cover the static one. Cannot run the kernel.
529+ if (StaticSize > MaxSize)
542530 return Plugin::error (ErrorCode::INVALID_ARGUMENT,
543531 " Static block memory size exceeds maximum" );
532+ // No enough block memory to cover dynamic one, and the fallback is aborting.
544533 else if (static_cast <DynCGroupMemFallbackType>(
545534 KernelArgs.Flags .DynCGroupMemFallback ) ==
546535 DynCGroupMemFallbackType::Abort &&
547- TotalBlockMemSize > MaxBlockMemSize )
536+ TotalSize > MaxSize )
548537 return Plugin::error (
549538 ErrorCode::INVALID_ARGUMENT,
550539 " Static and dynamic block memory size exceeds maximum" );
551540
552- void *DynBlockMemFbPtr = nullptr ;
553- uint32_t DynBlockMemLaunchSize = DynBlockMemSize;
554-
555- DynCGroupMemFallbackType DynBlockMemFb = DynCGroupMemFallbackType::None;
556- if (DynBlockMemSize && (!GenericDevice.hasNativeBlockSharedMem () ||
557- TotalBlockMemSize > MaxBlockMemSize)) {
541+ DynCGroupMemFallbackType DynFallback = DynCGroupMemFallbackType::None;
542+ if (DynSize && (!GenericDevice.hasNativeBlockSharedMem () ||
543+ TotalSize > MaxSize)) {
558544 // Launch without native dynamic block memory.
559- DynBlockMemLaunchSize = 0 ;
560- DynBlockMemFb = static_cast <DynCGroupMemFallbackType>(
545+ DynNativeSize = 0 ;
546+ DynFallback = static_cast <DynCGroupMemFallbackType>(
561547 KernelArgs.Flags .DynCGroupMemFallback );
562- if (DynBlockMemFb == DynCGroupMemFallbackType::DefaultMem) {
548+ if (DynFallback == DynCGroupMemFallbackType::DefaultMem) {
563549 // Get global memory as fallback.
564550 auto AllocOrErr = GenericDevice.dataAlloc (
565- NumBlocks[0 ] * DynBlockMemSize ,
551+ NumBlocks[0 ] * DynSize ,
566552 /* HostPtr=*/ nullptr , TargetAllocTy::TARGET_ALLOC_DEVICE);
567553 if (!AllocOrErr)
568554 return AllocOrErr.takeError ();
569-
570- DynBlockMemFbPtr = *AllocOrErr;
571- AsyncInfoWrapper.freeAllocationAfterSynchronization (DynBlockMemFbPtr);
555+ DynFallbackPtr = *AllocOrErr;
572556 } else {
573557 // Do not provide any memory as fallback.
574- DynBlockMemSize = 0 ;
558+ DynSize = 0 ;
575559 }
576560 }
561+ return { DynSize, DynNativeSize, DynFallback, DynFallbackPtr };
562+ }
563+
564+ Error GenericKernelTy::launch (GenericDeviceTy &GenericDevice, void **ArgPtrs,
565+ ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
566+ AsyncInfoWrapperTy &AsyncInfoWrapper) const {
567+ llvm::SmallVector<void *, 16 > Args;
568+ llvm::SmallVector<void *, 16 > Ptrs;
569+
570+ uint32_t NumThreads[3 ] = {KernelArgs.ThreadLimit [0 ],
571+ KernelArgs.ThreadLimit [1 ],
572+ KernelArgs.ThreadLimit [2 ]};
573+ uint32_t NumBlocks[3 ] = {KernelArgs.NumTeams [0 ], KernelArgs.NumTeams [1 ],
574+ KernelArgs.NumTeams [2 ]};
575+ if (!isBareMode ()) {
576+ NumThreads[0 ] = getNumThreads (GenericDevice, NumThreads);
577+ NumBlocks[0 ] = getNumBlocks (GenericDevice, NumBlocks, KernelArgs.Tripcount ,
578+ NumThreads[0 ], KernelArgs.ThreadLimit [0 ] > 0 );
579+ }
580+
581+ auto DynBlockMemInfoOrErr = prepareBlockMemory (GenericDevice, KernelArgs);
582+ if (!DynBlockMemInfoOrErr)
583+ return DynBlockMemInfoOrErr.takeError ();
584+
585+ DynBlockMemInfoTy &DynBlockMemInfo = *DynBlockMemInfoOrErr;
586+ if (DynBlockMemInfo.FallbackPtr )
587+ AsyncInfoWrapper.freeAllocationAfterSynchronization (DynBlockMemInfo.FallbackPtr );
577588
578589 auto KernelLaunchEnvOrErr = getKernelLaunchEnvironment (
579- GenericDevice, KernelArgs, DynBlockMemSize, DynBlockMemFb ,
580- DynBlockMemFbPtr , AsyncInfoWrapper);
590+ GenericDevice, KernelArgs, DynBlockMemInfo. Size , DynBlockMemInfo. Fallback ,
591+ DynBlockMemInfo. FallbackPtr , AsyncInfoWrapper);
581592 if (!KernelLaunchEnvOrErr)
582593 return KernelLaunchEnvOrErr.takeError ();
583594
@@ -608,7 +619,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
608619 printLaunchInfo (GenericDevice, KernelArgs, NumThreads, NumBlocks))
609620 return Err;
610621
611- return launchImpl (GenericDevice, NumThreads, NumBlocks, DynBlockMemLaunchSize ,
622+ return launchImpl (GenericDevice, NumThreads, NumBlocks, DynBlockMemInfo. NativeSize ,
612623 KernelArgs, LaunchParams, AsyncInfoWrapper);
613624}
614625
0 commit comments