Skip to content

Commit 6a18996

Browse files
committed
Factor out repeated code in as_c_contiguous_array_generic_impl
Also only enforce alignment on dst pointer
1 parent cdf8176 commit 6a18996

File tree

1 file changed

+56
-70
lines changed

1 file changed

+56
-70
lines changed

dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp

Lines changed: 56 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,55 @@ class CopyAsCContigFunctor
120120
}
121121
};
122122

123+
template <typename T,
124+
typename IndexerT,
125+
std::uint32_t n_vecs,
126+
std::uint32_t vec_sz,
127+
bool enable_sg_load,
128+
typename KernelName>
129+
sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
130+
size_t nelems,
131+
const T *src,
132+
T *dst,
133+
const IndexerT &src_indexer,
134+
const std::vector<sycl::event> &depends)
135+
{
136+
constexpr std::size_t preferred_lws = 256;
137+
138+
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
139+
140+
auto const &ctx = exec_q.get_context();
141+
auto const &dev = exec_q.get_device();
142+
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
143+
ctx, {dev}, {kernel_id});
144+
145+
auto krn = kb.get_kernel(kernel_id);
146+
147+
const std::uint32_t max_sg_size = krn.template get_info<
148+
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
149+
150+
const std::size_t lws =
151+
((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
152+
153+
constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
154+
size_t n_groups =
155+
(nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
156+
157+
sycl::event copy_ev = exec_q.submit([&](sycl::handler &cgh) {
158+
cgh.depends_on(depends);
159+
cgh.use_kernel_bundle(kb);
160+
161+
const sycl::range<1> gRange{n_groups * lws};
162+
const sycl::range<1> lRange{lws};
163+
164+
cgh.parallel_for<KernelName>(
165+
sycl::nd_range<1>(gRange, lRange),
166+
CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs, enable_sg_load>(
167+
nelems, src, dst, src_indexer));
168+
});
169+
return copy_ev;
170+
}
171+
123172
template <typename T,
124173
typename IndexT,
125174
int vec_sz,
@@ -145,7 +194,6 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
145194
using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
146195
const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
147196

148-
constexpr std::size_t preferred_lws = 256;
149197
constexpr std::uint32_t n_vecs = 2;
150198
constexpr std::uint32_t vec_sz = 4;
151199

@@ -155,84 +203,22 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
155203
using dpctl::tensor::kernels::alignment_utils::required_alignment;
156204

157205
sycl::event copy_ev;
158-
if (is_aligned<required_alignment>(src_p) &&
159-
is_aligned<required_alignment>(dst_p))
160-
{
206+
if (is_aligned<required_alignment>(dst_p)) {
161207
constexpr bool enable_sg_load = true;
162208
using KernelName =
163209
as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
164-
165-
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
166-
167-
auto const &ctx = exec_q.get_context();
168-
auto const &dev = exec_q.get_device();
169-
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
170-
ctx, {dev}, {kernel_id});
171-
172-
auto krn = kb.get_kernel(kernel_id);
173-
174-
const std::uint32_t max_sg_size = krn.template get_info<
175-
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
176-
177-
const std::size_t lws =
178-
((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
179-
180-
constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
181-
size_t n_groups =
182-
(nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
183-
184-
copy_ev = exec_q.submit([&](sycl::handler &cgh) {
185-
cgh.depends_on(depends);
186-
cgh.use_kernel_bundle(kb);
187-
188-
const sycl::range<1> gRange{n_groups * lws};
189-
const sycl::range<1> lRange{lws};
190-
191-
cgh.parallel_for<KernelName>(
192-
sycl::nd_range<1>(gRange, lRange),
193-
CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs,
194-
enable_sg_load>(nelems, src_tp, dst_tp,
195-
src_indexer));
196-
});
210+
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
211+
enable_sg_load, KernelName>(
212+
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
197213
}
198214
else {
199215
constexpr bool disable_sg_load = false;
200216
using InnerKernelName =
201217
as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
202218
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
203-
204-
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
205-
206-
auto const &ctx = exec_q.get_context();
207-
auto const &dev = exec_q.get_device();
208-
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
209-
ctx, {dev}, {kernel_id});
210-
211-
auto krn = kb.get_kernel(kernel_id);
212-
213-
const std::uint32_t max_sg_size = krn.template get_info<
214-
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
215-
216-
const std::size_t lws =
217-
((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
218-
219-
constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
220-
size_t n_groups =
221-
(nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
222-
223-
copy_ev = exec_q.submit([&](sycl::handler &cgh) {
224-
cgh.depends_on(depends);
225-
cgh.use_kernel_bundle(kb);
226-
227-
const sycl::range<1> gRange{n_groups * lws};
228-
const sycl::range<1> lRange{lws};
229-
230-
cgh.parallel_for<KernelName>(
231-
sycl::nd_range<1>(gRange, lRange),
232-
CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs,
233-
disable_sg_load>(nelems, src_tp, dst_tp,
234-
src_indexer));
235-
});
219+
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
220+
disable_sg_load, KernelName>(
221+
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
236222
}
237223

238224
return copy_ev;

0 commit comments

Comments
 (0)