diff --git a/sycl/test-e2e/GroupAlgorithm/root_group.cpp b/sycl/test-e2e/GroupAlgorithm/root_group.cpp index 2e50634fd21c8..80cd8b286b68b 100644 --- a/sycl/test-e2e/GroupAlgorithm/root_group.cpp +++ b/sycl/test-e2e/GroupAlgorithm/root_group.cpp @@ -18,136 +18,22 @@ #include #include -static constexpr int WorkGroupSize = 32; - -void testFeatureMacro() { - static_assert(SYCL_EXT_ONEAPI_ROOT_GROUP == 1, - "SYCL_EXT_ONEAPI_ROOT_GROUP must have a value of 1"); -} - -void testQueriesAndProperties() { - sycl::queue q; - const auto bundle = - sycl::get_kernel_bundle(q.get_context()); - const auto kernel = bundle.get_kernel(); - const auto local_range = sycl::range<1>(1); - const auto maxWGs = - kernel - .ext_oneapi_get_info( - q, local_range, 0); - const auto wgRange = sycl::range<3>{WorkGroupSize, 1, 1}; - const auto maxWGsWithLimits = - kernel - .ext_oneapi_get_info( - q, wgRange, wgRange.size() * sizeof(int)); - const auto props = sycl::ext::oneapi::experimental::properties{ - sycl::ext::oneapi::experimental::use_root_sync}; - q.single_task(props, []() {}); - - static auto check_max_num_work_group_sync = [](auto Result) { - static_assert(std::is_same_v, size_t>, - "max_num_work_group_sync query must return size_t"); - assert(Result >= 1 && "max_num_work_group_sync query failed"); - }; - check_max_num_work_group_sync(maxWGs); - check_max_num_work_group_sync(maxWGsWithLimits); -} - -void testRootGroup() { - sycl::queue q; - const auto bundle = - sycl::get_kernel_bundle(q.get_context()); - const auto kernel = bundle.get_kernel(); - const auto maxWGs = - kernel - .ext_oneapi_get_info( - q, WorkGroupSize, 0); - const auto props = sycl::ext::oneapi::experimental::properties{ - sycl::ext::oneapi::experimental::use_root_sync}; - sycl::buffer dataBuf{sycl::range{maxWGs * WorkGroupSize}}; - const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize}; - q.submit([&](sycl::handler &h) { - sycl::accessor data{dataBuf, h}; - h.parallel_for< - class RootGroupKernel>(range, props, [=](sycl::nd_item<1> it) { - volatile float X = 1.0f; - volatile float Y = 1.0f; - auto root = it.ext_oneapi_get_root_group(); - data[root.get_local_id()] = root.get_local_id(); - sycl::group_barrier(root); - // Delay half of the workgroups with extra work to check that the barrier - // synchronizes the whole device. - if (it.get_group(0) % 2 == 0) { - X += sycl::sin(X); - Y += sycl::cos(Y); - } - root = - sycl::ext::oneapi::experimental::this_work_item::get_root_group<1>(); - int sum = data[root.get_local_id()] + - data[root.get_local_range() - root.get_local_id() - 1]; - sycl::group_barrier(root); - data[root.get_local_id()] = sum; - }); - }); - sycl::host_accessor data{dataBuf}; - const int workItemCount = static_cast(range.get_global_range().size()); - for (int i = 0; i < workItemCount; i++) { - assert(data[i] == (workItemCount - 1)); +struct RootGroupKernel { + RootGroupKernel() {} + void operator()(sycl::nd_item<1> it) const { + auto root = it.ext_oneapi_get_root_group(); + sycl::group_barrier(root); } -} - -void testRootGroupFunctions() { - sycl::queue q; - const auto bundle = - sycl::get_kernel_bundle(q.get_context()); - const auto kernel = bundle.get_kernel(); - const auto maxWGs = - kernel - .ext_oneapi_get_info( - q, WorkGroupSize, 0); - const auto props = sycl::ext::oneapi::experimental::properties{ - sycl::ext::oneapi::experimental::use_root_sync}; - - constexpr int testCount = 9; - sycl::buffer testResultsBuf{sycl::range{testCount}}; - const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize}; - q.submit([&](sycl::handler &h) { - sycl::accessor testResults{testResultsBuf, h}; - h.parallel_for( - range, props, [=](sycl::nd_item<1> it) { - const auto root = it.ext_oneapi_get_root_group(); - if (root.leader() || root.get_local_id() == 3) { - testResults[0] = root.get_group_id() == sycl::id<1>(0); - testResults[1] = root.leader() - ? root.get_local_id() == sycl::id<1>(0) - : root.get_local_id() == sycl::id<1>(3); - testResults[2] = root.get_group_range() == sycl::range<1>(1); - testResults[3] = root.get_local_range() == it.get_global_range(); - testResults[4] = - root.get_max_local_range() == root.get_local_range(); - testResults[5] = root.get_group_linear_id() == 0; - testResults[6] = - root.get_local_linear_id() == root.get_local_id().get(0); - testResults[7] = root.get_group_linear_range() == 1; - testResults[8] = - root.get_local_linear_range() == root.get_local_range().size(); - } - }); - }); - sycl::host_accessor testResults{testResultsBuf}; - for (int i = 0; i < testCount; i++) { - assert(testResults[i]); + auto get(sycl::ext::oneapi::experimental::properties_tag) { + return sycl::ext::oneapi::experimental::properties{ + sycl::ext::oneapi::experimental::use_root_sync}; } -} +}; int main() { - testFeatureMacro(); - testQueriesAndProperties(); - testRootGroup(); - testRootGroupFunctions(); + sycl::queue q; + sycl::range<1> R1{1}; + sycl::nd_range<1> NDR1{R1, R1}; + q.submit([&](sycl::handler &h) { h.parallel_for(NDR1, RootGroupKernel()); }); return EXIT_SUCCESS; -} +} \ No newline at end of file