|
8 | 8 | // |
9 | 9 | //===----------------------------------------------------------------------===// |
10 | 10 | #include "command_buffer.hpp" |
| 11 | +#include "helpers/kernel_helpers.hpp" |
11 | 12 | #include "logger/ur_logger.hpp" |
12 | 13 | #include "ur_level_zero.hpp" |
13 | 14 |
|
@@ -78,130 +79,6 @@ preferCopyEngineForFill(ur_exp_command_buffer_handle_t CommandBuffer, |
78 | 79 | return UR_RESULT_SUCCESS; |
79 | 80 | } |
80 | 81 |
|
81 | | -/** |
82 | | - * Calculates a work group size for the kernel based on the GlobalWorkSize or |
83 | | - * the LocalWorkSize if provided. |
84 | | - * @param[in][optional] Kernel The Kernel. Used when LocalWorkSize is not |
85 | | - * provided. |
86 | | - * @param[in][optional] Device The device associated with the kernel. Used when |
87 | | - * LocalWorkSize is not provided. |
88 | | - * @param[out] ZeThreadGroupDimensions Number of work groups in each dimension. |
89 | | - * @param[out] WG The work group size for each dimension. |
90 | | - * @param[in] WorkDim The number of dimensions in the kernel. |
91 | | - * @param[in] GlobalWorkSize The global work size. |
92 | | - * @param[in][optional] LocalWorkSize The local work size. |
93 | | - * @return UR_RESULT_SUCCESS or an error code on failure. |
94 | | - */ |
95 | | -ur_result_t calculateKernelWorkDimensions( |
96 | | - ur_kernel_handle_t Kernel, ur_device_handle_t Device, |
97 | | - ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], |
98 | | - uint32_t WorkDim, const size_t *GlobalWorkSize, |
99 | | - const size_t *LocalWorkSize) { |
100 | | - |
101 | | - UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); |
102 | | - // If LocalWorkSize is not provided then Kernel must be provided to query |
103 | | - // suggested group size. |
104 | | - UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); |
105 | | - |
106 | | - // New variable needed because GlobalWorkSize parameter might not be of size |
107 | | - // 3 |
108 | | - size_t GlobalWorkSize3D[3]{1, 1, 1}; |
109 | | - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); |
110 | | - |
111 | | - if (LocalWorkSize) { |
112 | | - WG[0] = ur_cast<uint32_t>(LocalWorkSize[0]); |
113 | | - WG[1] = WorkDim >= 2 ? ur_cast<uint32_t>(LocalWorkSize[1]) : 1; |
114 | | - WG[2] = WorkDim == 3 ? ur_cast<uint32_t>(LocalWorkSize[2]) : 1; |
115 | | - } else { |
116 | | - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D |
117 | | - // values do not fit to 32-bit that the API only supports currently. |
118 | | - bool SuggestGroupSize = true; |
119 | | - for (int I : {0, 1, 2}) { |
120 | | - if (GlobalWorkSize3D[I] > UINT32_MAX) { |
121 | | - SuggestGroupSize = false; |
122 | | - } |
123 | | - } |
124 | | - if (SuggestGroupSize) { |
125 | | - ZE2UR_CALL(zeKernelSuggestGroupSize, |
126 | | - (Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], |
127 | | - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); |
128 | | - } else { |
129 | | - for (int I : {0, 1, 2}) { |
130 | | - // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is |
131 | | - // fully divisable with. Start with the max possible size in |
132 | | - // each dimension. |
133 | | - uint32_t GroupSize[] = { |
134 | | - Device->ZeDeviceComputeProperties->maxGroupSizeX, |
135 | | - Device->ZeDeviceComputeProperties->maxGroupSizeY, |
136 | | - Device->ZeDeviceComputeProperties->maxGroupSizeZ}; |
137 | | - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); |
138 | | - while (GlobalWorkSize3D[I] % GroupSize[I]) { |
139 | | - --GroupSize[I]; |
140 | | - } |
141 | | - if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { |
142 | | - logger::debug("calculateKernelWorkDimensions: can't find a WG size " |
143 | | - "suitable for global work size > UINT32_MAX"); |
144 | | - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
145 | | - } |
146 | | - WG[I] = GroupSize[I]; |
147 | | - } |
148 | | - logger::debug("calculateKernelWorkDimensions: using computed WG " |
149 | | - "size = {{{}, {}, {}}}", |
150 | | - WG[0], WG[1], WG[2]); |
151 | | - } |
152 | | - } |
153 | | - |
154 | | - // TODO: assert if sizes do not fit into 32-bit? |
155 | | - switch (WorkDim) { |
156 | | - case 3: |
157 | | - ZeThreadGroupDimensions.groupCountX = |
158 | | - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
159 | | - ZeThreadGroupDimensions.groupCountY = |
160 | | - ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]); |
161 | | - ZeThreadGroupDimensions.groupCountZ = |
162 | | - ur_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]); |
163 | | - break; |
164 | | - case 2: |
165 | | - ZeThreadGroupDimensions.groupCountX = |
166 | | - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
167 | | - ZeThreadGroupDimensions.groupCountY = |
168 | | - ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]); |
169 | | - WG[2] = 1; |
170 | | - break; |
171 | | - case 1: |
172 | | - ZeThreadGroupDimensions.groupCountX = |
173 | | - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
174 | | - WG[1] = WG[2] = 1; |
175 | | - break; |
176 | | - |
177 | | - default: |
178 | | - logger::error("calculateKernelWorkDimensions: unsupported work_dim"); |
179 | | - return UR_RESULT_ERROR_INVALID_VALUE; |
180 | | - } |
181 | | - |
182 | | - // Error handling for non-uniform group size case |
183 | | - if (GlobalWorkSize3D[0] != |
184 | | - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { |
185 | | - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
186 | | - "is not a multiple of the group size in the 1st dimension"); |
187 | | - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
188 | | - } |
189 | | - if (GlobalWorkSize3D[1] != |
190 | | - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { |
191 | | - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
192 | | - "is not a multiple of the group size in the 2nd dimension"); |
193 | | - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
194 | | - } |
195 | | - if (GlobalWorkSize3D[2] != |
196 | | - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { |
197 | | - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
198 | | - "is not a multiple of the group size in the 3rd dimension"); |
199 | | - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
200 | | - } |
201 | | - |
202 | | - return UR_RESULT_SUCCESS; |
203 | | -} |
204 | | - |
205 | 82 | /** |
206 | 83 | * Helper function for finding the Level Zero events associated with the |
207 | 84 | * commands in a command-buffer, each event is pointed to by a sync-point in the |
@@ -880,7 +757,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( |
880 | 757 |
|
881 | 758 | ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; |
882 | 759 | uint32_t WG[3]; |
883 | | - UR_CALL(calculateKernelWorkDimensions(Kernel, CommandBuffer->Device, |
| 760 | + UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device, |
884 | 761 | ZeThreadGroupDimensions, WG, WorkDim, |
885 | 762 | GlobalWorkSize, LocalWorkSize)); |
886 | 763 |
|
@@ -1587,8 +1464,8 @@ ur_result_t updateKernelCommand( |
1587 | 1464 |
|
1588 | 1465 | uint32_t WG[3]; |
1589 | 1466 | UR_CALL(calculateKernelWorkDimensions( |
1590 | | - Command->Kernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, |
1591 | | - Dim, NewGlobalWorkSize, NewLocalWorkSize)); |
| 1467 | + Command->Kernel->ZeKernel, CommandBuffer->Device, |
| 1468 | + ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize)); |
1592 | 1469 |
|
1593 | 1470 | auto MutableGroupCountDesc = |
1594 | 1471 | std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>(); |
|
0 commit comments