@@ -86,15 +86,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
8686 ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
8787 uint32_t WG[3 ]{};
8888
89- // global_work_size of unused dimensions must be set to 1
90- if (WorkDim >= 2 ) {
91- UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
92- UR_RESULT_ERROR_INVALID_VALUE);
93- if (WorkDim == 3 ) {
94- UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
95- UR_RESULT_ERROR_INVALID_VALUE);
96- }
97- }
89+ // New variable needed because GlobalWorkSize parameter might not be of size 3
90+ size_t GlobalWorkSize3D[3 ]{1 , 1 , 1 };
91+ std::copy (GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
92+
9893 if (LocalWorkSize) {
9994 // L0
10095 UR_ASSERT (LocalWorkSize[0 ] < (std::numeric_limits<uint32_t >::max)(),
@@ -111,14 +106,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
111106 // values do not fit to 32-bit that the API only supports currently.
112107 bool SuggestGroupSize = true ;
113108 for (int I : {0 , 1 , 2 }) {
114- if (GlobalWorkSize [I] > UINT32_MAX) {
109+ if (GlobalWorkSize3D [I] > UINT32_MAX) {
115110 SuggestGroupSize = false ;
116111 }
117112 }
118113 if (SuggestGroupSize) {
119114 ZE2UR_CALL (zeKernelSuggestGroupSize,
120- (ZeKernel, GlobalWorkSize [0 ], GlobalWorkSize [1 ],
121- GlobalWorkSize [2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
115+ (ZeKernel, GlobalWorkSize3D [0 ], GlobalWorkSize3D [1 ],
116+ GlobalWorkSize3D [2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
122117 } else {
123118 for (int I : {0 , 1 , 2 }) {
124119 // Try to find a I-dimension WG size that the GlobalWorkSize[I] is
@@ -128,11 +123,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
128123 Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeX ,
129124 Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeY ,
130125 Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeZ };
131- GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize [I]);
132- while (GlobalWorkSize [I] % GroupSize[I]) {
126+ GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize3D [I]);
127+ while (GlobalWorkSize3D [I] % GroupSize[I]) {
133128 --GroupSize[I];
134129 }
135- if (GlobalWorkSize [I] / GroupSize[I] > UINT32_MAX) {
130+ if (GlobalWorkSize3D [I] / GroupSize[I] > UINT32_MAX) {
136131 urPrint (" urEnqueueKernelLaunch: can't find a WG size "
137132 " suitable for global work size > UINT32_MAX\n " );
138133 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
@@ -149,22 +144,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
149144 switch (WorkDim) {
150145 case 3 :
151146 ZeThreadGroupDimensions.groupCountX =
152- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
147+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
153148 ZeThreadGroupDimensions.groupCountY =
154- static_cast <uint32_t >(GlobalWorkSize [1 ] / WG[1 ]);
149+ static_cast <uint32_t >(GlobalWorkSize3D [1 ] / WG[1 ]);
155150 ZeThreadGroupDimensions.groupCountZ =
156- static_cast <uint32_t >(GlobalWorkSize [2 ] / WG[2 ]);
151+ static_cast <uint32_t >(GlobalWorkSize3D [2 ] / WG[2 ]);
157152 break ;
158153 case 2 :
159154 ZeThreadGroupDimensions.groupCountX =
160- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
155+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
161156 ZeThreadGroupDimensions.groupCountY =
162- static_cast <uint32_t >(GlobalWorkSize [1 ] / WG[1 ]);
157+ static_cast <uint32_t >(GlobalWorkSize3D [1 ] / WG[1 ]);
163158 WG[2 ] = 1 ;
164159 break ;
165160 case 1 :
166161 ZeThreadGroupDimensions.groupCountX =
167- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
162+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
168163 WG[1 ] = WG[2 ] = 1 ;
169164 break ;
170165
@@ -174,19 +169,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
174169 }
175170
176171 // Error handling for non-uniform group size case
177- if (GlobalWorkSize [0 ] !=
172+ if (GlobalWorkSize3D [0 ] !=
178173 size_t (ZeThreadGroupDimensions.groupCountX ) * WG[0 ]) {
179174 urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
180175 " multiple of the group size in the 1st dimension\n " );
181176 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
182177 }
183- if (GlobalWorkSize [1 ] !=
178+ if (GlobalWorkSize3D [1 ] !=
184179 size_t (ZeThreadGroupDimensions.groupCountY ) * WG[1 ]) {
185180 urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
186181 " multiple of the group size in the 2nd dimension\n " );
187182 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
188183 }
189- if (GlobalWorkSize [2 ] !=
184+ if (GlobalWorkSize3D [2 ] !=
190185 size_t (ZeThreadGroupDimensions.groupCountZ ) * WG[2 ]) {
191186 urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
192187 " multiple of the group size in the 3rd dimension\n " );
@@ -450,10 +445,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
450445 }
451446
452447 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
453- for ( auto It : Kernel->ZeKernelMap ) {
454- auto ZeKernel = It. second ;
448+ if ( Kernel->ZeKernelMap . empty () ) {
449+ auto ZeKernel = Kernel-> ZeKernel ;
455450 ZE2UR_CALL (zeKernelSetArgumentValue,
456451 (ZeKernel, ArgIndex, ArgSize, PArgValue));
452+ } else {
453+ for (auto It : Kernel->ZeKernelMap ) {
454+ auto ZeKernel = It.second ;
455+ ZE2UR_CALL (zeKernelSetArgumentValue,
456+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
457+ }
457458 }
458459
459460 return UR_RESULT_SUCCESS;
0 commit comments