@@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
559559 }
560560
561561 /// Launch the AMDGPU kernel function.
562- Error launchImpl (GenericDeviceTy &GenericDevice, uint32_t NumThreads,
563- uint64_t NumBlocks, KernelArgsTy &KernelArgs,
562+ Error launchImpl (GenericDeviceTy &GenericDevice, uint32_t NumThreads[ 3 ] ,
563+ uint32_t NumBlocks[ 3 ] , KernelArgsTy &KernelArgs,
564564 KernelLaunchParamsTy LaunchParams,
565565 AsyncInfoWrapperTy &AsyncInfoWrapper) const override ;
566566
567567 /// Print more elaborate kernel launch info for AMDGPU.
568568 Error printLaunchInfoDetails (GenericDeviceTy &GenericDevice,
569- KernelArgsTy &KernelArgs, uint32_t NumThreads,
570- uint64_t NumBlocks) const override ;
569+ KernelArgsTy &KernelArgs, uint32_t NumThreads[ 3 ] ,
570+ uint32_t NumBlocks[ 3 ] ) const override ;
571571
572572 /// Get group and private segment kernel size.
573573 uint32_t getGroupSize () const { return GroupSize; } // Read-only accessor: returns GroupSize as-is.
@@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
719719 /// Push a kernel launch to the queue. The kernel launch requires an output
720720 /// signal and can define an optional input signal (nullptr if none).
721721 Error pushKernelLaunch (const AMDGPUKernelTy &Kernel, void *KernelArgs,
722- uint32_t NumThreads, uint64_t NumBlocks,
722+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
723723 uint32_t GroupSize, uint64_t StackSize,
724724 AMDGPUSignalTy *OutputSignal,
725725 AMDGPUSignalTy *InputSignal) {
@@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
746746 assert (Packet && " Invalid packet" );
747747
748748 // The first 32 bits of the packet are written after the other fields
749- uint16_t Setup = UINT16_C (1 ) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
750- Packet->workgroup_size_x = NumThreads;
751- Packet->workgroup_size_y = 1 ;
752- Packet->workgroup_size_z = 1 ;
749+ uint16_t Dims = NumBlocks[2] * NumThreads[2] > 1 // z extent above 1: dispatch is 3-D
750+ ? 3
751+ : 1 + (NumBlocks[1] * NumThreads[1] != 1); // otherwise 2-D when the y extent is not 1, else 1-D
752+ uint16_t Setup = static_cast<uint16_t>( // UINT16_C is only defined for integer literals, so cast the runtime value
753+ Dims << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS);
754+ Packet->workgroup_size_x = NumThreads[0 ];
755+ Packet->workgroup_size_y = NumThreads[1 ];
756+ Packet->workgroup_size_z = NumThreads[2 ];
753757 Packet->reserved0 = 0 ;
754- Packet->grid_size_x = NumBlocks * NumThreads;
755- Packet->grid_size_y = 1 ;
756- Packet->grid_size_z = 1 ;
758+ Packet->grid_size_x = NumBlocks[ 0 ] * NumThreads[ 0 ] ;
759+ Packet->grid_size_y = NumBlocks[ 1 ] * NumThreads[ 1 ] ;
760+ Packet->grid_size_z = NumBlocks[ 2 ] * NumThreads[ 2 ] ;
757761 Packet->private_segment_size =
758762 Kernel.usesDynamicStack () ? StackSize : Kernel.getPrivateSize ();
759763 Packet->group_segment_size = GroupSize;
@@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
12401244 /// the kernel finalizes. Once the kernel is finished, the stream will release
12411245 /// the kernel args buffer to the specified memory manager.
12421246 Error pushKernelLaunch (const AMDGPUKernelTy &Kernel, void *KernelArgs,
1243- uint32_t NumThreads, uint64_t NumBlocks,
1247+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
12441248 uint32_t GroupSize, uint64_t StackSize,
12451249 AMDGPUMemoryManagerTy &MemoryManager) {
12461250 if (Queue == nullptr )
@@ -2827,10 +2831,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
28272831 AsyncInfoWrapperTy AsyncInfoWrapper (*this , nullptr );
28282832
28292833 KernelArgsTy KernelArgs = {};
2830- if ( auto Err =
2831- AMDGPUKernel.launchImpl (* this , /* NumThread= */ 1u ,
2832- /* NumBlocks= */ 1ul , KernelArgs,
2833- KernelLaunchParamsTy{}, AsyncInfoWrapper))
2834+ uint32_t NumBlocksAndThreads[ 3 ] = { 1u , 1u , 1u };
2835+ if ( auto Err = AMDGPUKernel.launchImpl (
2836+ * this , NumBlocksAndThreads, NumBlocksAndThreads , KernelArgs,
2837+ KernelLaunchParamsTy{}, AsyncInfoWrapper))
28342838 return Err;
28352839
28362840 Error Err = Plugin::success ();
@@ -3328,7 +3332,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
33283332};
33293333
33303334Error AMDGPUKernelTy::launchImpl (GenericDeviceTy &GenericDevice,
3331- uint32_t NumThreads, uint64_t NumBlocks,
3335+ uint32_t NumThreads[ 3 ], uint32_t NumBlocks[ 3 ] ,
33323336 KernelArgsTy &KernelArgs,
33333337 KernelLaunchParamsTy LaunchParams,
33343338 AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -3385,13 +3389,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
33853389 // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
33863390 if (ImplArgs &&
33873391 getImplicitArgsSize () == sizeof (hsa_utils::AMDGPUImplicitArgsTy)) {
3388- ImplArgs->BlockCountX = NumBlocks;
3389- ImplArgs->BlockCountY = 1 ;
3390- ImplArgs->BlockCountZ = 1 ;
3391- ImplArgs->GroupSizeX = NumThreads;
3392- ImplArgs->GroupSizeY = 1 ;
3393- ImplArgs->GroupSizeZ = 1 ;
3394- ImplArgs->GridDims = 1 ;
3392+ ImplArgs->BlockCountX = NumBlocks[0 ];
3393+ ImplArgs->BlockCountY = NumBlocks[1 ];
3394+ ImplArgs->BlockCountZ = NumBlocks[2 ];
3395+ ImplArgs->GroupSizeX = NumThreads[0 ];
3396+ ImplArgs->GroupSizeY = NumThreads[1 ];
3397+ ImplArgs->GroupSizeZ = NumThreads[2 ];
3398+ ImplArgs->GridDims = NumBlocks[2 ] * NumThreads[2 ] > 1
3399+ ? 3
3400+ : 1 + (NumBlocks[1 ] * NumThreads[1 ] != 1 );
33953401 ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem ;
33963402 }
33973403
@@ -3402,8 +3408,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
34023408
34033409Error AMDGPUKernelTy::printLaunchInfoDetails (GenericDeviceTy &GenericDevice,
34043410 KernelArgsTy &KernelArgs,
3405- uint32_t NumThreads,
3406- uint64_t NumBlocks) const {
3411+ uint32_t NumThreads[ 3 ] ,
3412+ uint32_t NumBlocks[ 3 ] ) const {
34073413 // Only do all this when the output is requested
34083414 if (!(getInfoLevel () & OMP_INFOTYPE_PLUGIN_KERNEL))
34093415 return Plugin::success ();
@@ -3440,12 +3446,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
34403446 // S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
34413447 // Tripcount: loop tripcount for the kernel
34423448 INFO (OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId (),
3443- " #Args: %d Teams x Thrds: %4lux %4u (MaxFlatWorkGroupSize: %u) LDS "
3449+ " #Args: %d Teams x Thrds: %4ux %4u (MaxFlatWorkGroupSize: %u) LDS "
34443450 " Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
34453451 " %lu\n " ,
3446- ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
3447- GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
3448- LoopTripCount);
3452+ ArgNum, NumGroups[0 ] * NumGroups[1 ] * NumGroups[2 ],
3453+ ThreadsPerGroup[0 ] * ThreadsPerGroup[1 ] * ThreadsPerGroup[2 ],
3454+ MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
3455+ SGPRSpillCount, VGPRSpillCount, LoopTripCount);
34493456
34503457 return Plugin::success ();
34513458}
0 commit comments