@@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
559559 }
560560
561561 /// Launch the AMDGPU kernel function.
562- Error launchImpl (GenericDeviceTy &GenericDevice, uint32_t NumThreads,
563- uint64_t NumBlocks, KernelArgsTy &KernelArgs,
562+ Error launchImpl (GenericDeviceTy &GenericDevice, uint32_t NumThreads[ 3 ] ,
563+ uint32_t NumBlocks[ 3 ] , KernelArgsTy &KernelArgs,
564564 KernelLaunchParamsTy LaunchParams,
565565 AsyncInfoWrapperTy &AsyncInfoWrapper) const override ;
566566
567567 /// Print more elaborate kernel launch info for AMDGPU.
568568 Error printLaunchInfoDetails (GenericDeviceTy &GenericDevice,
569- KernelArgsTy &KernelArgs, uint32_t NumThreads,
570- uint64_t NumBlocks) const override ;
569+ KernelArgsTy &KernelArgs, uint32_t NumThreads[ 3 ] ,
570+ uint32_t NumBlocks[ 3 ] ) const override ;
571571
572572 /// Get group and private segment kernel size.
573573 uint32_t getGroupSize () const { return GroupSize; }
@@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
719719 /// Push a kernel launch to the queue. The kernel launch requires an output
720720 /// signal and can define an optional input signal (nullptr if none).
721721 Error pushKernelLaunch (const AMDGPUKernelTy &Kernel, void *KernelArgs,
722- uint32_t NumThreads, uint64_t NumBlocks,
722+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
723723 uint32_t GroupSize, uint64_t StackSize,
724724 AMDGPUSignalTy *OutputSignal,
725725 AMDGPUSignalTy *InputSignal) {
@@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
746746 assert (Packet && " Invalid packet" );
747747
748748 // The first 32 bits of the packet are written after the other fields
749- uint16_t Setup = UINT16_C (1 ) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
750- Packet->workgroup_size_x = NumThreads;
751- Packet->workgroup_size_y = 1 ;
752- Packet->workgroup_size_z = 1 ;
749+ uint16_t Dims = NumBlocks[2 ] * NumThreads[2 ] > 1
750+ ? 3
751+ : 1 + (NumBlocks[1 ] * NumThreads[1 ] != 1 );
752+ uint16_t Setup = UINT16_C (Dims)
753+ << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
754+ Packet->workgroup_size_x = NumThreads[0 ];
755+ Packet->workgroup_size_y = NumThreads[1 ];
756+ Packet->workgroup_size_z = NumThreads[2 ];
753757 Packet->reserved0 = 0 ;
754- Packet->grid_size_x = NumBlocks * NumThreads;
755- Packet->grid_size_y = 1 ;
756- Packet->grid_size_z = 1 ;
758+ Packet->grid_size_x = NumBlocks[ 0 ] * NumThreads[ 0 ] ;
759+ Packet->grid_size_y = NumBlocks[ 1 ] * NumThreads[ 1 ] ;
760+ Packet->grid_size_z = NumBlocks[ 2 ] * NumThreads[ 2 ] ;
757761 Packet->private_segment_size =
758762 Kernel.usesDynamicStack () ? StackSize : Kernel.getPrivateSize ();
759763 Packet->group_segment_size = GroupSize;
@@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
12401244 /// the kernel finalizes. Once the kernel is finished, the stream will release
12411245 /// the kernel args buffer to the specified memory manager.
12421246 Error pushKernelLaunch (const AMDGPUKernelTy &Kernel, void *KernelArgs,
1243- uint32_t NumThreads, uint64_t NumBlocks,
1247+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
12441248 uint32_t GroupSize, uint64_t StackSize,
12451249 AMDGPUMemoryManagerTy &MemoryManager) {
12461250 if (Queue == nullptr )
@@ -2829,10 +2833,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
28292833 AsyncInfoWrapperTy AsyncInfoWrapper (*this , nullptr );
28302834
28312835 KernelArgsTy KernelArgs = {};
2832- if ( auto Err =
2833- AMDGPUKernel.launchImpl (* this , /* NumThread= */ 1u ,
2834- /* NumBlocks= */ 1ul , KernelArgs,
2835- KernelLaunchParamsTy{}, AsyncInfoWrapper))
2836+ uint32_t NumBlocksAndThreads[ 3 ] = { 1u , 1u , 1u };
2837+ if ( auto Err = AMDGPUKernel.launchImpl (
2838+ * this , NumBlocksAndThreads, NumBlocksAndThreads , KernelArgs,
2839+ KernelLaunchParamsTy{}, AsyncInfoWrapper))
28362840 return Err;
28372841
28382842 Error Err = Plugin::success ();
@@ -3330,7 +3334,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
33303334};
33313335
33323336Error AMDGPUKernelTy::launchImpl (GenericDeviceTy &GenericDevice,
3333- uint32_t NumThreads, uint64_t NumBlocks,
3337+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
33343338 KernelArgsTy &KernelArgs,
33353339 KernelLaunchParamsTy LaunchParams,
33363340 AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -3387,13 +3391,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
33873391 // Only COV5 implicitargs need to be set. COV4 implicitargs are not used.
33883392 if (ImplArgs &&
33893393 getImplicitArgsSize () == sizeof (hsa_utils::AMDGPUImplicitArgsTy)) {
3390- ImplArgs->BlockCountX = NumBlocks;
3391- ImplArgs->BlockCountY = 1 ;
3392- ImplArgs->BlockCountZ = 1 ;
3393- ImplArgs->GroupSizeX = NumThreads;
3394- ImplArgs->GroupSizeY = 1 ;
3395- ImplArgs->GroupSizeZ = 1 ;
3396- ImplArgs->GridDims = 1 ;
3394+ ImplArgs->BlockCountX = NumBlocks[0 ];
3395+ ImplArgs->BlockCountY = NumBlocks[1 ];
3396+ ImplArgs->BlockCountZ = NumBlocks[2 ];
3397+ ImplArgs->GroupSizeX = NumThreads[0 ];
3398+ ImplArgs->GroupSizeY = NumThreads[1 ];
3399+ ImplArgs->GroupSizeZ = NumThreads[2 ];
3400+ ImplArgs->GridDims = NumBlocks[2 ] * NumThreads[2 ] > 1
3401+ ? 3
3402+ : 1 + (NumBlocks[1 ] * NumThreads[1 ] != 1 );
33973403 ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem ;
33983404 }
33993405
@@ -3404,8 +3410,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
34043410
34053411Error AMDGPUKernelTy::printLaunchInfoDetails (GenericDeviceTy &GenericDevice,
34063412 KernelArgsTy &KernelArgs,
3407- uint32_t NumThreads,
3408- uint64_t NumBlocks) const {
3413+ uint32_t NumThreads[ 3 ] ,
3414+ uint32_t NumBlocks[ 3 ] ) const {
34093415 // Only do all this when the output is requested
34103416 if (!(getInfoLevel () & OMP_INFOTYPE_PLUGIN_KERNEL))
34113417 return Plugin::success ();
@@ -3442,12 +3448,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
34423448 // S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
34433449 // Tripcount: loop tripcount for the kernel
34443450 INFO (OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId (),
3445- " #Args: %d Teams x Thrds: %4lux %4u (MaxFlatWorkGroupSize: %u) LDS "
3451+ " #Args: %d Teams x Thrds: %4ux %4u (MaxFlatWorkGroupSize: %u) LDS "
34463452 " Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
34473453 " %lu\n " ,
3448- ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
3449- GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
3450- LoopTripCount);
3454+ ArgNum, NumGroups[0 ] * NumGroups[1 ] * NumGroups[2 ],
3455+ ThreadsPerGroup[0 ] * ThreadsPerGroup[1 ] * ThreadsPerGroup[2 ],
3456+ MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
3457+ SGPRSpillCount, VGPRSpillCount, LoopTripCount);
34513458
34523459 return Plugin::success ();
34533460}
0 commit comments