Skip to content

Commit d7b4238

Browse files
Fixed add_contig_impl and add_matrix_vector_broadcasting_contig_impl
Corrected/added checks for validity of sub-groups reads/writes. Added -fno-approx-func flag to compile element-wise functions, as well as -fno-finite-math-only flag. Fixed test_cos_order test to account for NumPy using float16 for intermediate computations for inputs of type "i1", but CPU RT does not support float16.
1 parent 2d2bfc2 commit d7b4238

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ if (WIN32)
5353
endif()
5454
set_source_files_properties(
5555
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
56-
PROPERTIES COMPILE_OPTIONS "${_clang_prefx}-fno-approx-func")
56+
PROPERTIES COMPILE_OPTIONS "${_clang_prefx}-fno-approx-func;${_clang_prefx}-fno-finite-math-only")
5757
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
5858
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
5959
if(UNIX)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ struct AddContigFunctor
6666
(ndit.get_group(0) * ndit.get_local_range(0) +
6767
sg.get_group_id()[0] * maxsgSize);
6868

69-
if (base + n_vecs * vec_sz < nelems_) {
69+
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
70+
(sgSize == maxsgSize)) {
7071
using in_ptrT1 =
7172
sycl::multi_ptr<const argT1,
7273
sycl::access::address_space::global_space>;
@@ -428,7 +429,8 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
428429
cgh.depends_on(make_padded_vec_ev);
429430

430431
auto lwsRange = sycl::range<1>(lws);
431-
size_t n_groups = (n0 * n1 + lws - 1) / lws;
432+
size_t n_elems = n0 * n1;
433+
size_t n_groups = (n_elems + lws - 1) / lws;
432434
auto gwsRange = sycl::range<1>(n_groups * lws);
433435

434436
cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
@@ -438,24 +440,31 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
438440
auto sg = ndit.get_sub_group();
439441
size_t gid = ndit.get_global_linear_id();
440442

443+
std::uint8_t sgSize = sg.get_local_range()[0];
441444
size_t base = gid - sg.get_local_id()[0];
442445

443-
using in_ptrT1 =
444-
sycl::multi_ptr<const argT1,
445-
sycl::access::address_space::global_space>;
446-
using in_ptrT2 =
447-
sycl::multi_ptr<const argT2,
448-
sycl::access::address_space::global_space>;
449-
using res_ptrT =
450-
sycl::multi_ptr<resT,
451-
sycl::access::address_space::global_space>;
446+
if (base + sgSize < n_elems) {
447+
using in_ptrT1 = sycl::multi_ptr<
448+
const argT1, sycl::access::address_space::global_space>;
449+
using in_ptrT2 = sycl::multi_ptr<
450+
const argT2, sycl::access::address_space::global_space>;
451+
using res_ptrT = sycl::multi_ptr<
452+
resT, sycl::access::address_space::global_space>;
452453

453-
const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
454-
const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
454+
const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
455+
const argT2 vec_el =
456+
sg.load(in_ptrT2(&padded_vec[base % n1]));
455457

456-
resT res_el = mat_el + vec_el;
458+
resT res_el = mat_el + vec_el;
457459

458-
sg.store(res_ptrT(&res[base]), res_el);
460+
sg.store(res_ptrT(&res[base]), res_el);
461+
}
462+
else {
463+
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
464+
k += sgSize) {
465+
res[k] = mat[k] + padded_vec[k % n1];
466+
}
467+
}
459468
}
460469
);
461470
});

dpctl/tests/test_tensor_elementwise.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ def test_cos_usm_type(usm_type):
290290
expected_Y = np.empty(input_shape, dtype=arg_dt)
291291
expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6))
292292
expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3))
293-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
293+
tol = 8 * dpt.finfo(Y.dtype).resolution
294+
295+
np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
294296

295297

296298
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -309,7 +311,13 @@ def test_cos_order(dtype):
309311
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
310312
Y = dpt.cos(U, order=ord)
311313
expected_Y = np.cos(dpt.asnumpy(U))
312-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
314+
tol = 8 * max(
315+
dpt.finfo(Y.dtype).resolution,
316+
np.finfo(expected_Y.dtype).resolution,
317+
)
318+
np.testing.assert_allclose(
319+
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
320+
)
313321

314322

315323
@pytest.mark.parametrize("dtype", _all_dtypes)

0 commit comments

Comments
 (0)