@@ -148,41 +148,92 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
148148 constexpr std::size_t preferred_lws = 256 ;
149149 constexpr std::uint32_t n_vecs = 2 ;
150150 constexpr std::uint32_t vec_sz = 4 ;
151- constexpr bool enable_sg_load = true ;
152- using KernelName =
153- as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
154151
155- const auto &kernel_id = sycl::get_kernel_id<KernelName>();
152+ using dpctl::tensor::kernels::alignment_utils::
153+ disabled_sg_loadstore_wrapper_krn;
154+ using dpctl::tensor::kernels::alignment_utils::is_aligned;
155+ using dpctl::tensor::kernels::alignment_utils::required_alignment;
156156
157- auto const &ctx = exec_q.get_context ();
158- auto const &dev = exec_q.get_device ();
159- auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
160- ctx, {dev}, {kernel_id});
157+ sycl::event copy_ev;
158+ if (is_aligned<required_alignment>(src_p) &&
159+ is_aligned<required_alignment>(dst_p))
160+ {
161+ constexpr bool enable_sg_load = true ;
162+ using KernelName =
163+ as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
161164
162- auto krn = kb. get_kernel (kernel_id );
165+ const auto &kernel_id = sycl::get_kernel_id<KernelName>( );
163166
164- const std::uint32_t max_sg_size = krn.template get_info <
165- sycl::info::kernel_device_specific::max_sub_group_size>(dev);
167+ auto const &ctx = exec_q.get_context ();
168+ auto const &dev = exec_q.get_device ();
169+ auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
170+ ctx, {dev}, {kernel_id});
166171
167- const std::size_t lws =
168- ((preferred_lws + max_sg_size - 1 ) / max_sg_size) * max_sg_size;
172+ auto krn = kb.get_kernel (kernel_id);
169173
170- constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
171- size_t n_groups =
172- (nelems + nelems_per_wi * lws - 1 ) / (nelems_per_wi * lws);
174+ const std::uint32_t max_sg_size = krn.template get_info <
175+ sycl::info::kernel_device_specific::max_sub_group_size>(dev);
173176
174- sycl::event copy_ev = exec_q.submit ([&](sycl::handler &cgh) {
175- cgh.depends_on (depends);
176- cgh.use_kernel_bundle (kb);
177+ const std::size_t lws =
178+ ((preferred_lws + max_sg_size - 1 ) / max_sg_size) * max_sg_size;
177179
178- const sycl::range<1 > gRange {n_groups * lws};
179- const sycl::range<1 > lRange{lws};
180+ constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
181+ size_t n_groups =
182+ (nelems + nelems_per_wi * lws - 1 ) / (nelems_per_wi * lws);
180183
181- cgh.parallel_for <KernelName>(
182- sycl::nd_range<1 >(gRange , lRange),
183- CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs, enable_sg_load>(
184- nelems, src_tp, dst_tp, src_indexer));
185- });
184+ copy_ev = exec_q.submit ([&](sycl::handler &cgh) {
185+ cgh.depends_on (depends);
186+ cgh.use_kernel_bundle (kb);
187+
188+ const sycl::range<1 > gRange {n_groups * lws};
189+ const sycl::range<1 > lRange{lws};
190+
191+ cgh.parallel_for <KernelName>(
192+ sycl::nd_range<1 >(gRange , lRange),
193+ CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs,
194+ enable_sg_load>(nelems, src_tp, dst_tp,
195+ src_indexer));
196+ });
197+ }
198+ else {
199+ constexpr bool disable_sg_load = false ;
200+ using InnerKernelName =
201+ as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
202+ using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
203+
204+ const auto &kernel_id = sycl::get_kernel_id<KernelName>();
205+
206+ auto const &ctx = exec_q.get_context ();
207+ auto const &dev = exec_q.get_device ();
208+ auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
209+ ctx, {dev}, {kernel_id});
210+
211+ auto krn = kb.get_kernel (kernel_id);
212+
213+ const std::uint32_t max_sg_size = krn.template get_info <
214+ sycl::info::kernel_device_specific::max_sub_group_size>(dev);
215+
216+ const std::size_t lws =
217+ ((preferred_lws + max_sg_size - 1 ) / max_sg_size) * max_sg_size;
218+
219+ constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
220+ size_t n_groups =
221+ (nelems + nelems_per_wi * lws - 1 ) / (nelems_per_wi * lws);
222+
223+ copy_ev = exec_q.submit ([&](sycl::handler &cgh) {
224+ cgh.depends_on (depends);
225+ cgh.use_kernel_bundle (kb);
226+
227+ const sycl::range<1 > gRange {n_groups * lws};
228+ const sycl::range<1 > lRange{lws};
229+
230+ cgh.parallel_for <KernelName>(
231+ sycl::nd_range<1 >(gRange , lRange),
232+ CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs,
233+ disable_sg_load>(nelems, src_tp, dst_tp,
234+ src_indexer));
235+ });
236+ }
186237
187238 return copy_ev;
188239}
0 commit comments