@@ -1105,3 +1105,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
11051105 uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
11061106 Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
11071107}
1108+
1109+ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1110+ virtual void SetUp () override {
1111+ program_name = " saxpy_usm_local_mem" ;
1112+ UUR_RETURN_ON_FATAL_FAILURE (
1113+ urUpdatableCommandBufferExpExecutionTest::SetUp ());
1114+
1115+ if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1116+ GTEST_SKIP ()
1117+ << " Local memory argument update not supported on Level Zero." ;
1118+ }
1119+
1120+ // HIP has extra args for local memory so we define an offset for arg
1121+ // indices here for updating
1122+ hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0 ;
1123+ ur_device_usm_access_capability_flags_t shared_usm_flags;
1124+ ASSERT_SUCCESS (
1125+ uur::GetDeviceUSMSingleSharedSupport (device, shared_usm_flags));
1126+ if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1127+ GTEST_SKIP () << " Shared USM is not supported." ;
1128+ }
1129+
1130+ const size_t allocation_size =
1131+ sizeof (uint32_t ) * global_size * local_size;
1132+ for (auto &shared_ptr : shared_ptrs) {
1133+ ASSERT_SUCCESS (urUSMSharedAlloc (context, device, nullptr , nullptr ,
1134+ allocation_size, &shared_ptr));
1135+ ASSERT_NE (shared_ptr, nullptr );
1136+
1137+ std::vector<uint8_t > pattern (allocation_size);
1138+ uur::generateMemFillPattern (pattern);
1139+ std::memcpy (shared_ptr, pattern.data (), allocation_size);
1140+ }
1141+
1142+ std::array<size_t , 12 > index_order{};
1143+ if (backend != UR_PLATFORM_BACKEND_HIP) {
1144+ index_order = {3 , 2 , 4 , 5 , 1 , 0 };
1145+ } else {
1146+ index_order = {9 , 8 , 10 , 11 , 4 , 5 , 6 , 7 , 0 , 1 , 2 , 3 };
1147+ }
1148+ size_t current_index = 0 ;
1149+
1150+ // Index 3 is A
1151+ ASSERT_SUCCESS (urKernelSetArgValue (kernel, index_order[current_index++],
1152+ sizeof (A), nullptr , &A));
1153+ // Index 2 is output
1154+ ASSERT_SUCCESS (urKernelSetArgPointer (
1155+ kernel, index_order[current_index++], nullptr , shared_ptrs[0 ]));
1156+
1157+ // Index 4 is X
1158+ ASSERT_SUCCESS (urKernelSetArgPointer (
1159+ kernel, index_order[current_index++], nullptr , shared_ptrs[1 ]));
1160+ // Index 5 is Y
1161+ ASSERT_SUCCESS (urKernelSetArgPointer (
1162+ kernel, index_order[current_index++], nullptr , shared_ptrs[2 ]));
1163+
1164+ // Index 1 is local_mem_b arg
1165+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1166+ local_mem_b_size, nullptr ));
1167+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1168+ ASSERT_SUCCESS (urKernelSetArgValue (
1169+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1170+ nullptr , &hip_local_offset));
1171+ ASSERT_SUCCESS (urKernelSetArgValue (
1172+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1173+ nullptr , &hip_local_offset));
1174+ ASSERT_SUCCESS (urKernelSetArgValue (
1175+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1176+ nullptr , &hip_local_offset));
1177+ }
1178+
1179+ // Index 0 is local_mem_a arg
1180+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1181+ local_mem_a_size, nullptr ));
1182+
1183+ // Hip has extra args for local mem at index 1-3
1184+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1185+ ASSERT_SUCCESS (urKernelSetArgValue (
1186+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1187+ nullptr , &hip_local_offset));
1188+ ASSERT_SUCCESS (urKernelSetArgValue (
1189+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1190+ nullptr , &hip_local_offset));
1191+ ASSERT_SUCCESS (urKernelSetArgValue (
1192+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1193+ nullptr , &hip_local_offset));
1194+ }
1195+ }
1196+ };
1197+
1198+ struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1199+ void SetUp () override {
1200+ UUR_RETURN_ON_FATAL_FAILURE (
1201+ LocalMemoryUpdateTestBaseOutOfOrder::SetUp ());
1202+
1203+ // Append kernel command to command-buffer and close command-buffer
1204+ ASSERT_SUCCESS (urCommandBufferAppendKernelLaunchExp (
1205+ updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1206+ &global_size, &local_size, 0 , nullptr , 0 , nullptr , 0 , nullptr ,
1207+ nullptr , nullptr , &command_handle));
1208+ ASSERT_NE (command_handle, nullptr );
1209+
1210+ ASSERT_SUCCESS (urCommandBufferFinalizeExp (updatable_cmd_buf_handle));
1211+ }
1212+
1213+ void TearDown () override {
1214+ if (command_handle) {
1215+ EXPECT_SUCCESS (urCommandBufferReleaseCommandExp (command_handle));
1216+ }
1217+
1218+ UUR_RETURN_ON_FATAL_FAILURE (
1219+ LocalMemoryUpdateTestBaseOutOfOrder::TearDown ());
1220+ }
1221+
1222+ ur_exp_command_buffer_command_handle_t command_handle = nullptr ;
1223+ };
1224+
1225+ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P (LocalMemoryUpdateTestOutOfOrder);
1226+
1227+ // Test updating A,X,Y parameters to new values and local memory to larger
1228+ // values when the kernel arguments were added out of order.
1229+ TEST_P (LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1230+ // Run command-buffer prior to update and verify output
1231+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1232+ nullptr , nullptr ));
1233+ ASSERT_SUCCESS (urQueueFinish (queue));
1234+
1235+ uint32_t *output = (uint32_t *)shared_ptrs[0 ];
1236+ uint32_t *X = (uint32_t *)shared_ptrs[1 ];
1237+ uint32_t *Y = (uint32_t *)shared_ptrs[2 ];
1238+ Validate (output, X, Y, A, global_size, local_size);
1239+
1240+ // Update inputs
1241+ std::array<ur_exp_command_buffer_update_pointer_arg_desc_t , 2 >
1242+ new_input_descs;
1243+ std::array<ur_exp_command_buffer_update_value_arg_desc_t , 3 >
1244+ new_value_descs;
1245+
1246+ size_t new_local_size = local_size * 4 ;
1247+ size_t new_local_mem_a_size = new_local_size * sizeof (uint32_t );
1248+
1249+ // New local_mem_a at index 0
1250+ new_value_descs[0 ] = {
1251+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1252+ nullptr , // pNext
1253+ 0 , // argIndex
1254+ new_local_mem_a_size, // argSize
1255+ nullptr , // pProperties
1256+ nullptr , // hArgValue
1257+ };
1258+
1259+ // New local_mem_b at index 1
1260+ new_value_descs[1 ] = {
1261+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262+ nullptr , // pNext
1263+ 1 + hip_arg_offset, // argIndex
1264+ local_mem_b_size, // argSize
1265+ nullptr , // pProperties
1266+ nullptr , // hArgValue
1267+ };
1268+
1269+ // New A at index 3
1270+ uint32_t new_A = 33 ;
1271+ new_value_descs[2 ] = {
1272+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1273+ nullptr , // pNext
1274+ 3 + (2 * hip_arg_offset), // argIndex
1275+ sizeof (new_A), // argSize
1276+ nullptr , // pProperties
1277+ &new_A, // hArgValue
1278+ };
1279+
1280+ // New X at index 4
1281+ new_input_descs[0 ] = {
1282+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1283+ nullptr , // pNext
1284+ 4 + (2 * hip_arg_offset), // argIndex
1285+ nullptr , // pProperties
1286+ &shared_ptrs[3 ], // pArgValue
1287+ };
1288+
1289+ // New Y at index 5
1290+ new_input_descs[1 ] = {
1291+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1292+ nullptr , // pNext
1293+ 5 + (2 * hip_arg_offset), // argIndex
1294+ nullptr , // pProperties
1295+ &shared_ptrs[4 ], // pArgValue
1296+ };
1297+
1298+ // Update kernel inputs
1299+ ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1300+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1301+ nullptr , // pNext
1302+ kernel, // hNewKernel
1303+ 0 , // numNewMemObjArgs
1304+ new_input_descs.size (), // numNewPointerArgs
1305+ new_value_descs.size (), // numNewValueArgs
1306+ n_dimensions, // newWorkDim
1307+ nullptr , // pNewMemObjArgList
1308+ new_input_descs.data (), // pNewPointerArgList
1309+ new_value_descs.data (), // pNewValueArgList
1310+ nullptr , // pNewGlobalWorkOffset
1311+ nullptr , // pNewGlobalWorkSize
1312+ nullptr , // pNewLocalWorkSize
1313+ };
1314+
1315+ // Update kernel and enqueue command-buffer again
1316+ ASSERT_SUCCESS (
1317+ urCommandBufferUpdateKernelLaunchExp (command_handle, &update_desc));
1318+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1319+ nullptr , nullptr ));
1320+ ASSERT_SUCCESS (urQueueFinish (queue));
1321+
1322+ // Verify that update occurred correctly
1323+ uint32_t *new_output = (uint32_t *)shared_ptrs[0 ];
1324+ uint32_t *new_X = (uint32_t *)shared_ptrs[3 ];
1325+ uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1326+ Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1327+ }
0 commit comments