Skip to content

Commit 206eef0

Browse files
Moved linear seq. functions to dedicated files
1 parent 0e8cf49 commit 206eef0

File tree

4 files changed

+242
-118
lines changed

4 files changed

+242
-118
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pybind11_add_module(${python_module_name} MODULE
2222
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
2323
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
2424
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
25+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
2526
)
2627
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
2728
target_include_directories(${python_module_name}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#include "dpctl4pybind11.hpp"
26+
#include <CL/sycl.hpp>
27+
#include <complex>
28+
#include <pybind11/complex.h>
29+
#include <pybind11/pybind11.h>
30+
#include <utility>
31+
#include <vector>
32+
33+
#include "kernels/constructors.hpp"
34+
#include "utils/strided_iters.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
38+
#include "linear_sequences.hpp"
39+
40+
namespace py = pybind11;
41+
namespace _ns = dpctl::tensor::detail;
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace py_internal
48+
{
49+
50+
using dpctl::utils::keep_args_alive;
51+
52+
using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
53+
54+
static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[_ns::num_types];
55+
56+
using dpctl::tensor::kernels::constructors::lin_space_affine_fn_ptr_t;
57+
58+
static lin_space_affine_fn_ptr_t
59+
lin_space_affine_dispatch_vector[_ns::num_types];
60+
61+
std::pair<sycl::event, sycl::event>
62+
usm_ndarray_linear_sequence_step(py::object start,
63+
py::object dt,
64+
dpctl::tensor::usm_ndarray dst,
65+
sycl::queue exec_q,
66+
const std::vector<sycl::event> &depends)
67+
{
68+
// dst must be 1D and C-contiguous
69+
// start, end should be coercible into data type of dst
70+
71+
if (dst.get_ndim() != 1) {
72+
throw py::value_error(
73+
"usm_ndarray_linspace: Expecting 1D array to populate");
74+
}
75+
76+
if (!dst.is_c_contiguous()) {
77+
throw py::value_error(
78+
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
79+
}
80+
81+
sycl::queue dst_q = dst.get_queue();
82+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
83+
throw py::value_error(
84+
"Execution queue is not compatible with the allocation queue");
85+
}
86+
87+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
88+
int dst_typenum = dst.get_typenum();
89+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
90+
91+
py::ssize_t len = dst.get_shape(0);
92+
if (len == 0) {
93+
// nothing to do
94+
return std::make_pair(sycl::event{}, sycl::event{});
95+
}
96+
97+
char *dst_data = dst.get_data();
98+
sycl::event linspace_step_event;
99+
100+
auto fn = lin_space_step_dispatch_vector[dst_typeid];
101+
102+
linspace_step_event =
103+
fn(exec_q, static_cast<size_t>(len), start, dt, dst_data, depends);
104+
105+
return std::make_pair(keep_args_alive(exec_q, {dst}, {linspace_step_event}),
106+
linspace_step_event);
107+
}
108+
109+
std::pair<sycl::event, sycl::event>
110+
usm_ndarray_linear_sequence_affine(py::object start,
111+
py::object end,
112+
dpctl::tensor::usm_ndarray dst,
113+
bool include_endpoint,
114+
sycl::queue exec_q,
115+
const std::vector<sycl::event> &depends)
116+
{
117+
// dst must be 1D and C-contiguous
118+
// start, end should be coercible into data type of dst
119+
120+
if (dst.get_ndim() != 1) {
121+
throw py::value_error(
122+
"usm_ndarray_linspace: Expecting 1D array to populate");
123+
}
124+
125+
if (!dst.is_c_contiguous()) {
126+
throw py::value_error(
127+
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
128+
}
129+
130+
sycl::queue dst_q = dst.get_queue();
131+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
132+
throw py::value_error(
133+
"Execution queue context is not the same as allocation context");
134+
}
135+
136+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
137+
int dst_typenum = dst.get_typenum();
138+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
139+
140+
py::ssize_t len = dst.get_shape(0);
141+
if (len == 0) {
142+
// nothing to do
143+
return std::make_pair(sycl::event{}, sycl::event{});
144+
}
145+
146+
char *dst_data = dst.get_data();
147+
sycl::event linspace_affine_event;
148+
149+
auto fn = lin_space_affine_dispatch_vector[dst_typeid];
150+
151+
linspace_affine_event = fn(exec_q, static_cast<size_t>(len), start, end,
152+
include_endpoint, dst_data, depends);
153+
154+
return std::make_pair(
155+
keep_args_alive(exec_q, {dst}, {linspace_affine_event}),
156+
linspace_affine_event);
157+
}
158+
159+
void init_linear_sequences_dispatch_vectors(void)
160+
{
161+
using namespace dpctl::tensor::detail;
162+
using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory;
163+
using dpctl::tensor::kernels::constructors::LinSpaceStepFactory;
164+
165+
DispatchVectorBuilder<lin_space_step_fn_ptr_t, LinSpaceStepFactory,
166+
num_types>
167+
dvb1;
168+
dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);
169+
170+
DispatchVectorBuilder<lin_space_affine_fn_ptr_t, LinSpaceAffineFactory,
171+
num_types>
172+
dvb2;
173+
dvb2.populate_dispatch_vector(lin_space_affine_dispatch_vector);
174+
}
175+
176+
} // namespace py_internal
177+
} // namespace tensor
178+
} // namespace dpctl
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
extern std::pair<sycl::event, sycl::event>
41+
usm_ndarray_linear_sequence_step(py::object start,
42+
py::object dt,
43+
dpctl::tensor::usm_ndarray dst,
44+
sycl::queue exec_q,
45+
const std::vector<sycl::event> &depends = {});
46+
47+
extern std::pair<sycl::event, sycl::event> usm_ndarray_linear_sequence_affine(
48+
py::object start,
49+
py::object end,
50+
dpctl::tensor::usm_ndarray dst,
51+
bool include_endpoint,
52+
sycl::queue exec_q,
53+
const std::vector<sycl::event> &depends = {});
54+
55+
extern void init_linear_sequences_dispatch_vectors(void);
56+
57+
} // namespace py_internal
58+
} // namespace tensor
59+
} // namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 4 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "copy_and_cast_usm_to_usm.hpp"
4343
#include "copy_for_reshape.hpp"
4444
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
45+
#include "linear_sequences.hpp"
4546
#include "simplify_iteration_space.hpp"
4647

4748
namespace py = pybind11;
@@ -68,112 +69,8 @@ using dpctl::tensor::py_internal::copy_numpy_ndarray_into_usm_ndarray;
6869

6970
/* ============= linear-sequence ==================== */
7071

71-
using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
72-
73-
static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[_ns::num_types];
74-
75-
using dpctl::tensor::kernels::constructors::lin_space_affine_fn_ptr_t;
76-
77-
static lin_space_affine_fn_ptr_t
78-
lin_space_affine_dispatch_vector[_ns::num_types];
79-
80-
std::pair<sycl::event, sycl::event>
81-
usm_ndarray_linear_sequence_step(py::object start,
82-
py::object dt,
83-
dpctl::tensor::usm_ndarray dst,
84-
sycl::queue exec_q,
85-
const std::vector<sycl::event> &depends = {})
86-
{
87-
// dst must be 1D and C-contiguous
88-
// start, end should be coercible into data type of dst
89-
90-
if (dst.get_ndim() != 1) {
91-
throw py::value_error(
92-
"usm_ndarray_linspace: Expecting 1D array to populate");
93-
}
94-
95-
if (!dst.is_c_contiguous()) {
96-
throw py::value_error(
97-
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
98-
}
99-
100-
sycl::queue dst_q = dst.get_queue();
101-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
102-
throw py::value_error(
103-
"Execution queue is not compatible with the allocation queue");
104-
}
105-
106-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
107-
int dst_typenum = dst.get_typenum();
108-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
109-
110-
py::ssize_t len = dst.get_shape(0);
111-
if (len == 0) {
112-
// nothing to do
113-
return std::make_pair(sycl::event{}, sycl::event{});
114-
}
115-
116-
char *dst_data = dst.get_data();
117-
sycl::event linspace_step_event;
118-
119-
auto fn = lin_space_step_dispatch_vector[dst_typeid];
120-
121-
linspace_step_event =
122-
fn(exec_q, static_cast<size_t>(len), start, dt, dst_data, depends);
123-
124-
return std::make_pair(keep_args_alive(exec_q, {dst}, {linspace_step_event}),
125-
linspace_step_event);
126-
}
127-
128-
std::pair<sycl::event, sycl::event>
129-
usm_ndarray_linear_sequence_affine(py::object start,
130-
py::object end,
131-
dpctl::tensor::usm_ndarray dst,
132-
bool include_endpoint,
133-
sycl::queue exec_q,
134-
const std::vector<sycl::event> &depends = {})
135-
{
136-
// dst must be 1D and C-contiguous
137-
// start, end should be coercible into data type of dst
138-
139-
if (dst.get_ndim() != 1) {
140-
throw py::value_error(
141-
"usm_ndarray_linspace: Expecting 1D array to populate");
142-
}
143-
144-
if (!dst.is_c_contiguous()) {
145-
throw py::value_error(
146-
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
147-
}
148-
149-
sycl::queue dst_q = dst.get_queue();
150-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
151-
throw py::value_error(
152-
"Execution queue context is not the same as allocation context");
153-
}
154-
155-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
156-
int dst_typenum = dst.get_typenum();
157-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
158-
159-
py::ssize_t len = dst.get_shape(0);
160-
if (len == 0) {
161-
// nothing to do
162-
return std::make_pair(sycl::event{}, sycl::event{});
163-
}
164-
165-
char *dst_data = dst.get_data();
166-
sycl::event linspace_affine_event;
167-
168-
auto fn = lin_space_affine_dispatch_vector[dst_typeid];
169-
170-
linspace_affine_event = fn(exec_q, static_cast<size_t>(len), start, end,
171-
include_endpoint, dst_data, depends);
172-
173-
return std::make_pair(
174-
keep_args_alive(exec_q, {dst}, {linspace_affine_event}),
175-
linspace_affine_event);
176-
}
72+
using dpctl::tensor::py_internal::usm_ndarray_linear_sequence_affine;
73+
using dpctl::tensor::py_internal::usm_ndarray_linear_sequence_step;
17774

17875
/* ================ Full ================== */
17976

@@ -537,25 +434,14 @@ void init_dispatch_tables(void)
537434
void init_dispatch_vectors(void)
538435
{
539436
dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
437+
dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
540438

541439
using namespace dpctl::tensor::detail;
542440
using dpctl::tensor::kernels::constructors::EyeFactory;
543441
using dpctl::tensor::kernels::constructors::FullContigFactory;
544-
using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory;
545-
using dpctl::tensor::kernels::constructors::LinSpaceStepFactory;
546442
using dpctl::tensor::kernels::constructors::TrilGenericFactory;
547443
using dpctl::tensor::kernels::constructors::TriuGenericFactory;
548444

549-
DispatchVectorBuilder<lin_space_step_fn_ptr_t, LinSpaceStepFactory,
550-
num_types>
551-
dvb1;
552-
dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);
553-
554-
DispatchVectorBuilder<lin_space_affine_fn_ptr_t, LinSpaceAffineFactory,
555-
num_types>
556-
dvb2;
557-
dvb2.populate_dispatch_vector(lin_space_affine_dispatch_vector);
558-
559445
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
560446
dvb3;
561447
dvb3.populate_dispatch_vector(full_contig_dispatch_vector);

0 commit comments

Comments (0)