Skip to content

Commit 45d116d

Browse files
committed
Improve kernel launch
1 parent 1c4fd2e commit 45d116d

File tree

2 files changed

+64
-39
lines changed

2 files changed

+64
-39
lines changed

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,20 @@ struct GenericKernelTy {
439439
uint32_t NumBlocks[3]) const;
440440

441441
private:
442+
/// Information about the dynamic block memory needed for launching a kernel.
443+
struct DynBlockMemInfoTy {
444+
/// The size of the dynamic block memory buffer.
445+
uint32_t Size = 0;
446+
/// The size of dynamic shared memory natively provided by the device.
447+
uint32_t NativeSize = 0;
448+
/// The fallback that was triggered (if any).
449+
DynCGroupMemFallbackType DynBlockMemFb = DynCGroupMemFallbackType::None;
450+
/// The fallback pointer if global memory was used as an alternative.
451+
void *FallbackPtr = nullptr;
452+
};
453+
454+
Expected<DynBlockMemInfoTy> prepareBlockMemory(GenericDeviceTy &GenericDevice, KernelArgsTy &KernelArgs);
455+
442456
/// Prepare the arguments before launching the kernel.
443457
KernelLaunchParamsTy
444458
prepareArgs(GenericDeviceTy &GenericDevice, void **ArgPtrs,

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -518,66 +518,77 @@ Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
518518
return Plugin::success();
519519
}
520520

521-
Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
522-
ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
523-
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
524-
llvm::SmallVector<void *, 16> Args;
525-
llvm::SmallVector<void *, 16> Ptrs;
526-
527-
uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
528-
KernelArgs.ThreadLimit[1],
529-
KernelArgs.ThreadLimit[2]};
530-
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
531-
KernelArgs.NumTeams[2]};
532-
if (!isBareMode()) {
533-
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
534-
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
535-
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
536-
}
537-
538-
uint32_t MaxBlockMemSize = GenericDevice.getMaxBlockSharedMemSize();
539-
uint32_t DynBlockMemSize = KernelArgs.DynCGroupMem;
540-
uint32_t TotalBlockMemSize = StaticBlockMemSize + DynBlockMemSize;
541-
if (StaticBlockMemSize > MaxBlockMemSize)
521+
Expected<DynBlockMemInfoTy> prepareBlockMemory(GenericDeviceTy &GenericDevice, KernelArgsTy &KernelArgs) {
522+
uint32_t MaxSize = GenericDevice.getMaxBlockSharedMemSize();
523+
uint32_t DynSize = KernelArgs.DynCGroupMem;
524+
uint32_t TotalSize = StaticSize + DynSize;
525+
uint32_t DynNativeSize = DynSize;
526+
void *DynFallbackPtr = nullptr;
527+
528+
// Not enough block memory to cover the static one. Cannot run the kernel.
529+
if (StaticSize > MaxSize)
542530
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
543531
"Static block memory size exceeds maximum");
532+
// Not enough block memory to cover the dynamic one, and the fallback is aborting.
544533
else if (static_cast<DynCGroupMemFallbackType>(
545534
KernelArgs.Flags.DynCGroupMemFallback) ==
546535
DynCGroupMemFallbackType::Abort &&
547-
TotalBlockMemSize > MaxBlockMemSize)
536+
TotalSize > MaxSize)
548537
return Plugin::error(
549538
ErrorCode::INVALID_ARGUMENT,
550539
"Static and dynamic block memory size exceeds maximum");
551540

552-
void *DynBlockMemFbPtr = nullptr;
553-
uint32_t DynBlockMemLaunchSize = DynBlockMemSize;
554-
555-
DynCGroupMemFallbackType DynBlockMemFb = DynCGroupMemFallbackType::None;
556-
if (DynBlockMemSize && (!GenericDevice.hasNativeBlockSharedMem() ||
557-
TotalBlockMemSize > MaxBlockMemSize)) {
541+
DynCGroupMemFallbackType DynFallback = DynCGroupMemFallbackType::None;
542+
if (DynSize && (!GenericDevice.hasNativeBlockSharedMem() ||
543+
TotalSize > MaxSize)) {
558544
// Launch without native dynamic block memory.
559-
DynBlockMemLaunchSize = 0;
560-
DynBlockMemFb = static_cast<DynCGroupMemFallbackType>(
545+
DynNativeSize = 0;
546+
DynFallback = static_cast<DynCGroupMemFallbackType>(
561547
KernelArgs.Flags.DynCGroupMemFallback);
562-
if (DynBlockMemFb == DynCGroupMemFallbackType::DefaultMem) {
548+
if (DynFallback == DynCGroupMemFallbackType::DefaultMem) {
563549
// Get global memory as fallback.
564550
auto AllocOrErr = GenericDevice.dataAlloc(
565-
NumBlocks[0] * DynBlockMemSize,
551+
NumBlocks[0] * DynSize,
566552
/*HostPtr=*/nullptr, TargetAllocTy::TARGET_ALLOC_DEVICE);
567553
if (!AllocOrErr)
568554
return AllocOrErr.takeError();
569-
570-
DynBlockMemFbPtr = *AllocOrErr;
571-
AsyncInfoWrapper.freeAllocationAfterSynchronization(DynBlockMemFbPtr);
555+
DynFallbackPtr = *AllocOrErr;
572556
} else {
573557
// Do not provide any memory as fallback.
574-
DynBlockMemSize = 0;
558+
DynSize = 0;
575559
}
576560
}
561+
return { DynSize, DynNativeSize, DynFallback, DynFallbackPtr };
562+
}
563+
564+
Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
565+
ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
566+
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
567+
llvm::SmallVector<void *, 16> Args;
568+
llvm::SmallVector<void *, 16> Ptrs;
569+
570+
uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
571+
KernelArgs.ThreadLimit[1],
572+
KernelArgs.ThreadLimit[2]};
573+
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
574+
KernelArgs.NumTeams[2]};
575+
if (!isBareMode()) {
576+
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
577+
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
578+
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
579+
}
580+
581+
auto DynBlockMemInfoOrErr = prepareBlockMemory(GenericDevice, KernelArgs);
582+
if (!DynBlockMemInfoOrErr)
583+
return DynBlockMemInfoOrErr.takeError();
584+
585+
DynBlockMemInfoTy &DynBlockMemInfo = *DynBlockMemInfoOrErr;
586+
if (DynBlockMemInfo.FallbackPtr)
587+
AsyncInfoWrapper.freeAllocationAfterSynchronization(DynBlockMemInfo.FallbackPtr);
577588

578589
auto KernelLaunchEnvOrErr = getKernelLaunchEnvironment(
579-
GenericDevice, KernelArgs, DynBlockMemSize, DynBlockMemFb,
580-
DynBlockMemFbPtr, AsyncInfoWrapper);
590+
GenericDevice, KernelArgs, DynBlockMemInfo.Size, DynBlockMemInfo.Fallback,
591+
DynBlockMemInfo.FallbackPtr, AsyncInfoWrapper);
581592
if (!KernelLaunchEnvOrErr)
582593
return KernelLaunchEnvOrErr.takeError();
583594

@@ -608,7 +619,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
608619
printLaunchInfo(GenericDevice, KernelArgs, NumThreads, NumBlocks))
609620
return Err;
610621

611-
return launchImpl(GenericDevice, NumThreads, NumBlocks, DynBlockMemLaunchSize,
622+
return launchImpl(GenericDevice, NumThreads, NumBlocks, DynBlockMemInfo.NativeSize,
612623
KernelArgs, LaunchParams, AsyncInfoWrapper);
613624
}
614625

0 commit comments

Comments
 (0)