@@ -21,9 +21,8 @@ namespace {
2121using lzt::to_int;
2222using lzt::to_u32;
2323
24- class CooperativeKernelTests
25- : public ::testing::Test,
26- public ::testing::WithParamInterface<std::tuple<uint32_t , bool >> {
24+ class CooperativeKernelTests : public ::testing::Test,
25+ public ::testing::WithParamInterface<bool > {
2726protected:
2827 void RunGivenCooperativeKernelWhenAppendingLaunchCooperativeKernelTest (
2928 bool is_shared_system);
@@ -32,7 +31,6 @@ class CooperativeKernelTests
3231void CooperativeKernelTests::
3332 RunGivenCooperativeKernelWhenAppendingLaunchCooperativeKernelTest (
3433 bool is_shared_system) {
35- uint32_t max_group_count = 0 ;
3634 ze_module_handle_t module = nullptr ;
3735 ze_kernel_handle_t kernel = nullptr ;
3836 auto driver = lzt::get_default_driver ();
@@ -57,20 +55,11 @@ void CooperativeKernelTests::
5755 LOG_WARNING << " No command queues that support cooperative kernels" ;
5856 GTEST_SKIP ();
5957 }
60- auto is_immediate = std::get< 1 >( GetParam () );
58+ auto is_immediate = GetParam ();
6159 auto cmd_bundle =
6260 lzt::create_command_bundle (context, device, flags, mode, priority, 0 ,
6361 to_u32 (ordinal), 0 , is_immediate);
6462
65- // Set up input vector data
66- const size_t data_size = 4096 ;
67- uint64_t kernel_data[data_size] = {0 };
68- void *input_data = lzt::allocate_shared_memory_with_allocator_selector (
69- sizeof (uint64_t ) * data_size, 1 , 0 , 0 , device, context, is_shared_system);
70-
71- memcpy (input_data, kernel_data, data_size * sizeof (uint64_t ));
72-
73- uint32_t row_num = std::get<0 >(GetParam ());
7463 uint32_t groups_x = 1 ;
7564
7665 module = lzt::create_module (context, device, " cooperative_kernel.spv" ,
@@ -83,16 +72,14 @@ void CooperativeKernelTests::
8372 ASSERT_ZE_RESULT_SUCCESS (
8473 zeKernelSuggestMaxCooperativeGroupCount (kernel, &groups_x));
8574 ASSERT_GT (groups_x, 0 );
86- // We've set the number of workgroups to the max
87- // so check that it is sufficient for the input,
88- // otherwise just clamp the test
89- if (groups_x < row_num) {
90- row_num = groups_x;
91- }
75+
76+ void *input_data = lzt::allocate_shared_memory_with_allocator_selector (
77+ sizeof ( int32_t ) * groups_x, 1 , 0 , 0 , device, context, is_shared_system);
78+
79+ memset (input_data, 0 , sizeof ( int32_t ) * groups_x) ;
80+
9281 ASSERT_ZE_RESULT_SUCCESS (
9382 zeKernelSetArgumentValue (kernel, 0 , sizeof (input_data), &input_data));
94- ASSERT_ZE_RESULT_SUCCESS (
95- zeKernelSetArgumentValue (kernel, 1 , sizeof (row_num), &row_num));
9683
9784 ze_group_count_t args = {groups_x, 1 , 1 };
9885 ASSERT_ZE_RESULT_SUCCESS (zeCommandListAppendLaunchCooperativeKernel (
@@ -102,10 +89,8 @@ void CooperativeKernelTests::
10289 lzt::execute_and_sync_command_bundle (cmd_bundle, UINT64_MAX);
10390
10491 // Validate the kernel completed successfully and correctly
105- uint64_t val = 0 ;
106- for (uint32_t i = 0U ; i <= row_num; i++) {
107- val = i + row_num;
108- ASSERT_EQ (static_cast <uint64_t *>(input_data)[i], val);
92+ for (uint32_t i = 0U ; i < groups_x; i++) {
93+ ASSERT_EQ (static_cast <int *>(input_data)[i], i);
10994 }
11095
11196 lzt::free_memory_with_allocator_selector (context, input_data,
@@ -128,10 +113,7 @@ LZT_TEST_P(
128113 RunGivenCooperativeKernelWhenAppendingLaunchCooperativeKernelTest (true );
129114}
130115
131- INSTANTIATE_TEST_SUITE_P (
132- // 62 is the max row such that no calculation will overflow max uint64 value
133- GroupNumbers, CooperativeKernelTests,
134- ::testing::Combine (::testing::Values(0 , 1 , 5 , 10 , 50 , 62 ),
135- ::testing::Bool()));
116+ INSTANTIATE_TEST_SUITE_P (GroupNumbers, CooperativeKernelTests,
117+ ::testing::Bool ());
136118
137119} // namespace
0 commit comments