@@ -1094,3 +1094,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
10941094 uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
10951095 Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
10961096}
1097+
1098+ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1099+ virtual void SetUp () override {
1100+ program_name = " saxpy_usm_local_mem" ;
1101+ UUR_RETURN_ON_FATAL_FAILURE (
1102+ urUpdatableCommandBufferExpExecutionTest::SetUp ());
1103+
1104+ if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1105+ GTEST_SKIP ()
1106+ << " Local memory argument update not supported on Level Zero." ;
1107+ }
1108+
1109+ // HIP has extra args for local memory so we define an offset for arg
1110+ // indices here for updating
1111+ hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0 ;
1112+ ur_device_usm_access_capability_flags_t shared_usm_flags;
1113+ ASSERT_SUCCESS (
1114+ uur::GetDeviceUSMSingleSharedSupport (device, shared_usm_flags));
1115+ if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1116+ GTEST_SKIP () << " Shared USM is not supported." ;
1117+ }
1118+
1119+ const size_t allocation_size =
1120+ sizeof (uint32_t ) * global_size * local_size;
1121+ for (auto &shared_ptr : shared_ptrs) {
1122+ ASSERT_SUCCESS (urUSMSharedAlloc (context, device, nullptr , nullptr ,
1123+ allocation_size, &shared_ptr));
1124+ ASSERT_NE (shared_ptr, nullptr );
1125+
1126+ std::vector<uint8_t > pattern (allocation_size);
1127+ uur::generateMemFillPattern (pattern);
1128+ std::memcpy (shared_ptr, pattern.data (), allocation_size);
1129+ }
1130+
1131+ std::array<size_t , 12 > index_order{};
1132+ if (backend != UR_PLATFORM_BACKEND_HIP) {
1133+ index_order = {3 , 2 , 4 , 5 , 1 , 0 };
1134+ } else {
1135+ index_order = {9 , 8 , 10 , 11 , 4 , 5 , 6 , 7 , 0 , 1 , 2 , 3 };
1136+ }
1137+ size_t current_index = 0 ;
1138+
1139+ // Index 3 is A
1140+ ASSERT_SUCCESS (urKernelSetArgValue (kernel, index_order[current_index++],
1141+ sizeof (A), nullptr , &A));
1142+ // Index 2 is output
1143+ ASSERT_SUCCESS (urKernelSetArgPointer (
1144+ kernel, index_order[current_index++], nullptr , shared_ptrs[0 ]));
1145+
1146+ // Index 4 is X
1147+ ASSERT_SUCCESS (urKernelSetArgPointer (
1148+ kernel, index_order[current_index++], nullptr , shared_ptrs[1 ]));
1149+ // Index 5 is Y
1150+ ASSERT_SUCCESS (urKernelSetArgPointer (
1151+ kernel, index_order[current_index++], nullptr , shared_ptrs[2 ]));
1152+
1153+ // Index 1 is local_mem_b arg
1154+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1155+ local_mem_b_size, nullptr ));
1156+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1157+ ASSERT_SUCCESS (urKernelSetArgValue (
1158+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1159+ nullptr , &hip_local_offset));
1160+ ASSERT_SUCCESS (urKernelSetArgValue (
1161+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1162+ nullptr , &hip_local_offset));
1163+ ASSERT_SUCCESS (urKernelSetArgValue (
1164+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1165+ nullptr , &hip_local_offset));
1166+ }
1167+
1168+ // Index 0 is local_mem_a arg
1169+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1170+ local_mem_a_size, nullptr ));
1171+
1172+ // Hip has extra args for local mem at index 1-3
1173+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1174+ ASSERT_SUCCESS (urKernelSetArgValue (
1175+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1176+ nullptr , &hip_local_offset));
1177+ ASSERT_SUCCESS (urKernelSetArgValue (
1178+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1179+ nullptr , &hip_local_offset));
1180+ ASSERT_SUCCESS (urKernelSetArgValue (
1181+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1182+ nullptr , &hip_local_offset));
1183+ }
1184+ }
1185+ };
1186+
1187+ struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1188+ void SetUp () override {
1189+ UUR_RETURN_ON_FATAL_FAILURE (
1190+ LocalMemoryUpdateTestBaseOutOfOrder::SetUp ());
1191+
1192+ // Append kernel command to command-buffer and close command-buffer
1193+ ASSERT_SUCCESS (urCommandBufferAppendKernelLaunchExp (
1194+ updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1195+ &global_size, &local_size, 0 , nullptr , 0 , nullptr , 0 , nullptr ,
1196+ nullptr , nullptr , &command_handle));
1197+ ASSERT_NE (command_handle, nullptr );
1198+
1199+ ASSERT_SUCCESS (urCommandBufferFinalizeExp (updatable_cmd_buf_handle));
1200+ }
1201+
1202+ void TearDown () override {
1203+ if (command_handle) {
1204+ EXPECT_SUCCESS (urCommandBufferReleaseCommandExp (command_handle));
1205+ }
1206+
1207+ UUR_RETURN_ON_FATAL_FAILURE (
1208+ LocalMemoryUpdateTestBaseOutOfOrder::TearDown ());
1209+ }
1210+
1211+ ur_exp_command_buffer_command_handle_t command_handle = nullptr ;
1212+ };
1213+
1214+ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P (LocalMemoryUpdateTestOutOfOrder);
1215+
1216+ // Test updating A,X,Y parameters to new values and local memory to larger
1217+ // values when the kernel arguments were added out of order.
1218+ TEST_P (LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1219+ // Run command-buffer prior to update and verify output
1220+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1221+ nullptr , nullptr ));
1222+ ASSERT_SUCCESS (urQueueFinish (queue));
1223+
1224+ uint32_t *output = (uint32_t *)shared_ptrs[0 ];
1225+ uint32_t *X = (uint32_t *)shared_ptrs[1 ];
1226+ uint32_t *Y = (uint32_t *)shared_ptrs[2 ];
1227+ Validate (output, X, Y, A, global_size, local_size);
1228+
1229+ // Update inputs
1230+ std::array<ur_exp_command_buffer_update_pointer_arg_desc_t , 2 >
1231+ new_input_descs;
1232+ std::array<ur_exp_command_buffer_update_value_arg_desc_t , 3 >
1233+ new_value_descs;
1234+
1235+ size_t new_local_size = local_size * 4 ;
1236+ size_t new_local_mem_a_size = new_local_size * sizeof (uint32_t );
1237+
1238+ // New local_mem_a at index 0
1239+ new_value_descs[0 ] = {
1240+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1241+ nullptr , // pNext
1242+ 0 , // argIndex
1243+ new_local_mem_a_size, // argSize
1244+ nullptr , // pProperties
1245+ nullptr , // hArgValue
1246+ };
1247+
1248+ // New local_mem_b at index 1
1249+ new_value_descs[1 ] = {
1250+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1251+ nullptr , // pNext
1252+ 1 + hip_arg_offset, // argIndex
1253+ local_mem_b_size, // argSize
1254+ nullptr , // pProperties
1255+ nullptr , // hArgValue
1256+ };
1257+
1258+ // New A at index 3
1259+ uint32_t new_A = 33 ;
1260+ new_value_descs[2 ] = {
1261+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262+ nullptr , // pNext
1263+ 3 + (2 * hip_arg_offset), // argIndex
1264+ sizeof (new_A), // argSize
1265+ nullptr , // pProperties
1266+ &new_A, // hArgValue
1267+ };
1268+
1269+ // New X at index 4
1270+ new_input_descs[0 ] = {
1271+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1272+ nullptr , // pNext
1273+ 4 + (2 * hip_arg_offset), // argIndex
1274+ nullptr , // pProperties
1275+ &shared_ptrs[3 ], // pArgValue
1276+ };
1277+
1278+ // New Y at index 5
1279+ new_input_descs[1 ] = {
1280+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1281+ nullptr , // pNext
1282+ 5 + (2 * hip_arg_offset), // argIndex
1283+ nullptr , // pProperties
1284+ &shared_ptrs[4 ], // pArgValue
1285+ };
1286+
1287+ // Update kernel inputs
1288+ ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1289+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1290+ nullptr , // pNext
1291+ kernel, // hNewKernel
1292+ 0 , // numNewMemObjArgs
1293+ new_input_descs.size (), // numNewPointerArgs
1294+ new_value_descs.size (), // numNewValueArgs
1295+ n_dimensions, // newWorkDim
1296+ nullptr , // pNewMemObjArgList
1297+ new_input_descs.data (), // pNewPointerArgList
1298+ new_value_descs.data (), // pNewValueArgList
1299+ nullptr , // pNewGlobalWorkOffset
1300+ nullptr , // pNewGlobalWorkSize
1301+ nullptr , // pNewLocalWorkSize
1302+ };
1303+
1304+ // Update kernel and enqueue command-buffer again
1305+ ASSERT_SUCCESS (
1306+ urCommandBufferUpdateKernelLaunchExp (command_handle, &update_desc));
1307+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1308+ nullptr , nullptr ));
1309+ ASSERT_SUCCESS (urQueueFinish (queue));
1310+
1311+ // Verify that update occurred correctly
1312+ uint32_t *new_output = (uint32_t *)shared_ptrs[0 ];
1313+ uint32_t *new_X = (uint32_t *)shared_ptrs[3 ];
1314+ uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1315+ Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1316+ }
0 commit comments