Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
* Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
* Added a new backend routine `syrk` from oneMKL to perform a symmetric rank-k update, which is used for a specialized matrix multiplication where the result is a symmetric matrix [#2509](https://github.com/IntelPython/dpnp/pull/2509)

### Changed

Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
22 changes: 16 additions & 6 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "dotu.hpp"
#include "gemm.hpp"
#include "gemv.hpp"
#include "syrk.hpp"

namespace blas_ns = dpnp::extensions::blas;
namespace py = pybind11;
Expand All @@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void)
blas_ns::init_gemm_batch_dispatch_table();
blas_ns::init_gemm_dispatch_table();
blas_ns::init_gemv_dispatch_vector();
blas_ns::init_syrk_dispatch_vector();
}

static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
Expand All @@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dot", dot_pyapi,
"Call `dot` from OneMKL BLAS library to compute "
"Call `dot` from oneMKL BLAS library to compute "
"the dot product of two real-valued vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
Expand All @@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dotc", dotc_pyapi,
"Call `dotc` from OneMKL BLAS library to compute "
"Call `dotc` from oneMKL BLAS library to compute "
"the dot product of two complex vectors, "
"conjugating the first vector.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
Expand All @@ -110,37 +112,45 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dotu", dotu_pyapi,
"Call `dotu` from OneMKL BLAS library to compute "
"Call `dotu` from oneMKL BLAS library to compute "
"the dot product of two complex vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
}

{
m.def("_gemm", &blas_ns::gemm,
"Call `gemm` from OneMKL BLAS library to compute "
"Call `gemm` from oneMKL BLAS library to compute "
"the matrix-matrix product with 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemm_batch", &blas_ns::gemm_batch,
"Call `gemm_batch` from OneMKL BLAS library to compute "
"Call `gemm_batch` from oneMKL BLAS library to compute "
"the matrix-matrix product for a batch of 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemv", &blas_ns::gemv,
"Call `gemv` from OneMKL BLAS library to compute "
"Call `gemv` from oneMKL BLAS library to compute "
"the matrix-vector product with a general matrix.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
py::arg("vectorY"), py::arg("transpose"),
py::arg("depends") = py::list());
}

{
m.def("_syrk", &blas_ns::syrk,
"Call `syrk` from oneMKL BLAS library to compute "
"the matrix-vector product with a general matrix.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"),
py::arg("depends") = py::list());
}

{
m.def(
"_using_onemath",
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/blas/dot_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ std::pair<sycl::event, sycl::event>
dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
if (dot_fn == nullptr) {
throw py::value_error(
"Types of input vectors and result array are mismatched.");
"No dot implementation is available for the specified data type "
"of the input and output arrays.");
}

char *x_typeless_ptr = vectorX.get_data();
Expand Down
11 changes: 7 additions & 4 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
// stride between successive rows (for row major layout).
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
Tab(0), // Scaling factor for matrix C.
Expand Down Expand Up @@ -168,7 +167,8 @@ std::tuple<sycl::event, sycl::event, bool>
const int resultC_nd = resultC.get_ndim();

if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
throw py::value_error("Input matrices must be two-dimensional.");
throw py::value_error(
"Input and output matrices must be two-dimensional.");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
Expand Down Expand Up @@ -286,6 +286,8 @@ std::tuple<sycl::event, sycl::event, bool>
}
}
else {
// both A and B are f_contig so using column-major gemm and
// no transpose is needed
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
lda = m;
Expand Down Expand Up @@ -313,7 +315,8 @@ std::tuple<sycl::event, sycl::event, bool>
gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
if (gemm_fn == nullptr) {
throw py::value_error(
"Types of input matrices and result matrix are mismatched.");
"No gemm implementation is available for the specified data type "
"of the input and output arrays.");
}

const char *a_typeless_ptr = matrixA.get_data();
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ std::tuple<sycl::event, sycl::event, bool>
gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
if (gemm_batch_fn == nullptr) {
throw py::value_error(
"Types of input matrices and result matrix are mismatched.");
"No gemm_batch implementation is available for the specified data "
"type of the input and output arrays.");
}

const char *a_typeless_ptr = matrixA.get_data();
Expand Down
49 changes: 24 additions & 25 deletions dpnp/backend/extensions/blas/gemv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
T(1), // Scaling factor for the matrix-vector product.
a, // Pointer to the input matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
// stride between successive rows (for row major layout).
x, // Pointer to the input vector x.
incx, // The stride of vector x.
T(0), // Scaling factor for vector y.
Expand Down Expand Up @@ -190,6 +189,26 @@ std::pair<sycl::event, sycl::event>
const py::ssize_t *a_shape = matrixA.get_shape_raw();
const py::ssize_t *x_shape = vectorX.get_shape_raw();
const py::ssize_t *y_shape = vectorY.get_shape_raw();
if (transpose) {
if (a_shape[0] != x_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in X.");
}
if (a_shape[1] != y_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in Y.");
}
}
else {
if (a_shape[1] != x_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in X.");
}
if (a_shape[0] != y_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in Y.");
}
}

oneapi::mkl::transpose transA;
std::size_t src_nelems;
Expand Down Expand Up @@ -243,27 +262,6 @@ std::pair<sycl::event, sycl::event>
}
#endif // USE_ONEMATH_CUBLAS

if (transpose) {
if (a_shape[0] != x_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in X.");
}
if (a_shape[1] != y_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in Y.");
}
}
else {
if (a_shape[1] != x_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in X.");
}
if (a_shape[0] != y_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in Y.");
}
}

const std::int64_t lda = is_row_major ? n : m;
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY);
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY,
Expand All @@ -284,10 +282,11 @@ std::pair<sycl::event, sycl::event>
gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_vector[type_id];
if (gemv_fn == nullptr) {
throw py::value_error(
"Types of input arrays and result array are mismatched.");
"No gemv implementation is available for the specified data type "
"of the input and output arrays.");
}

char *a_typeless_ptr = matrixA.get_data();
const char *a_typeless_ptr = matrixA.get_data();
char *x_typeless_ptr = vectorX.get_data();
char *y_typeless_ptr = vectorY.get_data();

Expand Down
1 change: 0 additions & 1 deletion dpnp/backend/extensions/blas/gemv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,4 @@ extern std::pair<sycl::event, sycl::event>
const std::vector<sycl::event> &depends);

extern void init_gemv_dispatch_vector(void);
extern void init_gemv_batch_dispatch_vector(void);
} // namespace dpnp::extensions::blas
Loading
Loading