|
9 | 9 | #ifndef __DPCT_DNNL_UTILS_HPP__ |
10 | 10 | #define __DPCT_DNNL_UTILS_HPP__ |
11 | 11 |
|
12 | | -#include <oneapi/dpl/algorithm> |
13 | | -#include <oneapi/dpl/execution> |
14 | | -#include <oneapi/dpl/numeric> |
15 | | -#include <oneapi/mkl.hpp> |
16 | | -#include <oneapi/mkl/rng/device.hpp> |
17 | 12 | #include <sycl/sycl.hpp> |
| 13 | + |
| 14 | +#include <dpct/dpct.hpp> |
| 15 | + |
18 | 16 | #include <oneapi/dnnl/dnnl.hpp> |
19 | 17 | #include <oneapi/dnnl/dnnl_sycl.hpp> |
20 | | -#include <unordered_map> |
| 18 | +#include <oneapi/mkl.hpp> |
| 19 | +#include <oneapi/mkl/rng/device.hpp> |
| 20 | + |
21 | 21 | #include <algorithm> |
22 | 22 | #include <list> |
| 23 | +#include <unordered_map> |
23 | 24 |
|
24 | | -#include "memory.hpp" |
25 | 25 | #include "device.hpp" |
26 | 26 | #include "lib_common_utils.hpp" |
| 27 | +#include "memory.hpp" |
27 | 28 |
|
28 | 29 | namespace dpct { |
29 | 30 | namespace dnnl { |
@@ -1023,20 +1024,21 @@ class engine_ext { |
1023 | 1024 | return q->fill<T>(static_cast<T *>(src), *static_cast<const T *>(value), |
1024 | 1025 | size_with_byte / sizeof(T)); |
1025 | 1026 | } |
/// Unary functor that maps a zero value to 1 and passes every other value
/// through unchanged.
template <typename T> struct no_zero_op {
  /// @param e Value to inspect.
  /// @return 1 (converted to T) when `e` tests false, otherwise `e` itself.
  ///
  /// The operator is const-qualified because the functor carries no state;
  /// this lets it be invoked through const objects and const references.
  T operator()(T e) const {
    if (!e) {
      return 1;
    }
    return e;
  }
};
1034 | 1027 | template <typename T> |
1035 | 1028 | void transform_no_zero_with_type(sycl::queue *q, void *src, void *dst, |
1036 | 1029 | size_t num) { |
1037 | | - std::transform(oneapi::dpl::execution::make_device_policy(*q), |
1038 | | - static_cast<T *>(src), static_cast<T *>(src) + num, |
1039 | | - static_cast<T *>(dst), no_zero_op<T>()); |
| 1030 | + q->submit([&](sycl::handler &cgh) { |
| 1031 | + cgh.parallel_for<dpct_kernel_name<class zero_to_one, T>>( |
| 1032 | + sycl::range<1>(num), [=](sycl::id<1> idx) { |
| 1033 | + T *src_ptr = static_cast<T *>(src) + idx[0]; |
| 1034 | + T *dst_ptr = static_cast<T *>(dst) + idx[0]; |
| 1035 | + if (*src_ptr) { |
| 1036 | + *dst_ptr = *src_ptr; |
| 1037 | + } else { |
| 1038 | + *dst_ptr = 1; |
| 1039 | + } |
| 1040 | + }); |
| 1041 | + }); |
1040 | 1042 | } |
1041 | 1043 | void transform_no_zero(const memory_desc_ext &desc, void *src, void *dst); |
1042 | 1044 | ::dnnl::memory::desc get_group_weight_desc(int group_count, |
|
0 commit comments