Skip to content
199 changes: 199 additions & 0 deletions help_function/src/util_group_load_store_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// ====------ onedpl_test_group_load.cpp------------ *- C++ -* ----===//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// ====------ onedpl_test_group_load.cpp------------ *- C++ -* ----===//
// ====------ util_group_load_store_test.cpp------------ *- C++ -* ----===//


//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//
// ===----------------------------------------------------------------------===//
#include <dpct/dpct.hpp>
#include <dpct/dpl_utils.hpp>
#include <iostream>
#include <oneapi/dpl/iterator>
#include <sycl/sycl.hpp>

template <dpct::group::load_algorithm T>
bool helper_validation_function(const int *ptr, const char *func_name) {
if constexpr (T == dpct::group::load_algorithm::BLOCK_LOAD_DIRECT) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the only difference between verification for "direct" and "striped" load algorithms is the name. This makes sense because we should "unmap" any "map" via the store operation as long as they match.

However, it doesn't make sense to just have a branch and repeat the code. You can branch for the name, and run the same verification, or just embed the info you want in func_name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that makes sense, thanks

for (int i = 0; i < 512; ++i) {
if (ptr[i] != i) {
std::cout << func_name << "_blocked"
<< " failed\n";
std::ostream_iterator<int> Iter(std::cout, ", ");
std::copy(ptr, ptr + 512, Iter);
std::cout << std::endl;
return false;
}
}
std::cout << func_name << "_blocked"
<< " pass\n";
} else {
for (int i = 0; i < 512; ++i) {
if (ptr[i] != i) {
std::cout << func_name << "_striped"
<< " failed\n";
std::ostream_iterator<int> Iter(std::cout, ", ");
std::copy(ptr, ptr + 512, Iter);
std::cout << std::endl;
return false;
}
}
std::cout << func_name << "_striped"
<< " pass\n";
}
return true;
}

bool subgroup_helper_validation_function(const int *ptr, const uint32_t *sg_sz,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this has the same verification as well (as it should), but it can be consolidated into a single routine.

const char *func_name) {
for (int i = 0; i < 512; ++i) {
if (ptr[i] != i) {
std::cout << " failed\n";
std::ostream_iterator<int> Iter(std::cout, ", ");
std::copy(ptr, ptr + 512, Iter);
std::cout << std::endl;
return false;
}
}

std::cout << func_name << " pass\n";
return true;
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store() {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED in its entirety as API
// functions
// Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT
// dpct::group::store_algorithm::BLOCK_STORE_STRIPED in its entirety as API
// functions
sycl::queue q(dpct::get_default_queue());
oneapi::dpl::counting_iterator<int> count_it(0);
sycl::buffer<int, 1> buffer(count_it, count_it + 512);

q.submit([&](sycl::handler &h) {
using group_load =
dpct::group::workgroup_load<4, T, int, const int *, sycl::nd_item<3>>;
using group_store =
dpct::group::workgroup_store<4, S, int, int *, sycl::nd_item<3>>;
size_t temp_storage_size = group_load::get_local_memory_size(128);
sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(temp_storage_size), h);
sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
data_accessor_read_write.get_multi_ptr<sycl::access::decorated::yes>()
.get();
auto *tmp = tacc.get_multi_ptr<sycl::access::decorated::yes>().get();
// Load thread_data of each work item to blocked arrangement
group_load(tmp).load(item, d_r_w, thread_data);
// Store thread_data of each work item from blocked arrangement
group_store(tmp).store(item, d_r_w, thread_data);
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_write);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function<T>(ptr, "test_group_load_store");
}

bool test_load_store_subgroup_striped_standalone() {
// Tests dpct::group::load_subgroup_striped as standalone method
sycl::queue q(dpct::get_default_queue());
int data[512];
for (int i = 0; i < 512; i++)
data[i] = i;
sycl::buffer<int, 1> buffer(data, 512);
sycl::buffer<uint32_t, 1> sg_sz_buf{sycl::range<1>(1)};

q.submit([&](sycl::handler &h) {
sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
sycl::accessor sg_sz_dacc(sg_sz_buf, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>().get();
auto *sg_sz_acc =
sg_sz_dacc.get_multi_ptr<sycl::access::decorated::yes>().get();
size_t gid = item.get_global_linear_id();
if (gid == 0) {
sg_sz_acc[0] = item.get_sub_group().get_local_linear_range();
}
dpct::group::uninitialized_load_subgroup_striped<4, int>(item, d_r_w,
thread_data);
dpct::group::store_subgroup_striped<4, int>(item, d_r_w, thread_data);
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_only);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
sycl::host_accessor data_accessor_sg(sg_sz_buf, sycl::read_only);
const uint32_t *ptr_sg =
data_accessor_sg.get_multi_ptr<sycl::access::decorated::yes>();
return subgroup_helper_validation_function(
ptr, ptr_sg, "test_subgroup_striped_standalone");
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_standalone() {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT &
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods
// Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT &
// dpct::group::store_algorithm::BLOCK_STORE_STRIPED as standalone methods
sycl::queue q(dpct::get_default_queue());
int data[512];
for (int i = 0; i < 512; i++)
data[i] = i;
sycl::buffer<int, 1> buffer(data, 512);

q.submit([&](sycl::handler &h) {
sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>().get();
// Load thread_data of each work item to blocked arrangement
if (T == dpct::group::load_algorithm::BLOCK_LOAD_DIRECT) {
dpct::group::load_blocked<4, int>(item, d_r_w, thread_data);
} else {
dpct::group::load_striped<4, int>(item, d_r_w, thread_data);
}
// Store thread_data of each work item from blocked arrangement
if (S == dpct::group::store_algorithm::BLOCK_STORE_DIRECT) {
dpct::group::store_blocked<4, int>(item, d_r_w, thread_data);
} else {
dpct::group::store_striped<4, int>(item, d_r_w, thread_data);
}
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_only);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function<T>(ptr, "test_group_load_store");
}

int main() {

return !(
// Calls test_group_load with blocked and striped strategies , should pass
// both results.
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>() &&
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
// Calls test_load_subgroup_striped_standalone and should pass
test_load_store_subgroup_striped_standalone() &&
// Calls test_group_load_standalone with blocked and striped strategies as
// free functions, should pass both results.
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>());
}