Skip to content

Commit 8906b0e

Browse files
authored
Added implementation of dpnp.special.erfc (#2588)
The PR add the implementation of `dpnp.special.erfc` function, including tests coverage. Also it reworks the previous implementation of erf function to reduce future code duplication. because most of erf-like function has the same types supported and similar signature.
1 parent 99f96cf commit 8906b0e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+856
-708
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
* Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
1919
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
2020
* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)
21+
* Added implementation of `dpnp.special.erfc` [#2588](https://github.com/IntelPython/dpnp/pull/2588)
2122

2223
### Changed
2324

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ endif()
6060

6161
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
6262

63-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
64-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
63+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
6564
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
6665

6766
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
#include <pybind11/pybind11.h>
3131
#include <pybind11/stl.h>
3232

33+
// utils extension header
34+
#include "ext/common.hpp"
35+
3336
#include "dot.hpp"
3437
#include "dot_common.hpp"
3538
#include "dotc.hpp"
@@ -41,7 +44,9 @@
4144
namespace blas_ns = dpnp::extensions::blas;
4245
namespace py = pybind11;
4346
namespace dot_ns = blas_ns::dot;
47+
4448
using dot_ns::dot_impl_fn_ptr_t;
49+
using ext::common::init_dispatch_vector;
4550

4651
// populate dispatch vectors and tables
4752
void init_dispatch_vectors_tables(void)
@@ -64,7 +69,7 @@ PYBIND11_MODULE(_blas_impl, m)
6469
using event_vecT = std::vector<sycl::event>;
6570

6671
{
67-
dot_ns::init_dot_dispatch_vector<blas_ns::DotContigFactory>(
72+
init_dispatch_vector<dot_impl_fn_ptr_t, blas_ns::DotContigFactory>(
6873
dot_dispatch_vector);
6974

7075
auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
@@ -82,7 +87,7 @@ PYBIND11_MODULE(_blas_impl, m)
8287
}
8388

8489
{
85-
dot_ns::init_dot_dispatch_vector<blas_ns::DotcContigFactory>(
90+
init_dispatch_vector<dot_impl_fn_ptr_t, blas_ns::DotcContigFactory>(
8691
dotc_dispatch_vector);
8792

8893
auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
@@ -101,7 +106,7 @@ PYBIND11_MODULE(_blas_impl, m)
101106
}
102107

103108
{
104-
dot_ns::init_dot_dispatch_vector<blas_ns::DotuContigFactory>(
109+
init_dispatch_vector<dot_impl_fn_ptr_t, blas_ns::DotuContigFactory>(
105110
dotu_dispatch_vector);
106111

107112
auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,

dpnp/backend/extensions/blas/dot_common.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,4 @@ std::pair<sycl::event, sycl::event>
165165

166166
return std::make_pair(args_ev, dot_ev);
167167
}
168-
169-
template <template <typename fnT, typename T> typename factoryT>
170-
void init_dot_dispatch_vector(dot_impl_fn_ptr_t dot_dispatch_vector[])
171-
{
172-
dpctl_td_ns::DispatchVectorBuilder<dot_impl_fn_ptr_t, factoryT,
173-
dpctl_td_ns::num_types>
174-
contig;
175-
contig.populate_dispatch_vector(dot_dispatch_vector);
176-
}
177168
} // namespace dpnp::extensions::blas::dot

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/output_validation.hpp"
@@ -35,14 +38,14 @@
3538
#include "gemm.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::blas
4142
{
4243
namespace mkl_blas = oneapi::mkl::blas;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_table;
48+
4649
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
4750
oneapi::mkl::transpose,
4851
oneapi::mkl::transpose,
@@ -339,9 +342,7 @@ struct GemmContigFactory
339342

340343
void init_gemm_dispatch_table(void)
341344
{
342-
dpctl_td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t, GemmContigFactory,
343-
dpctl_td_ns::num_types>
344-
contig;
345-
contig.populate_dispatch_table(gemm_dispatch_table);
345+
init_dispatch_table<gemm_impl_fn_ptr_t, GemmContigFactory>(
346+
gemm_dispatch_table);
346347
}
347348
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/output_validation.hpp"
@@ -35,14 +38,14 @@
3538
#include "gemm.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::blas
4142
{
4243
namespace mkl_blas = oneapi::mkl::blas;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_table;
48+
4649
typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
4750
sycl::queue &,
4851
const std::int64_t,
@@ -415,10 +418,7 @@ struct GemmBatchContigFactory
415418

416419
void init_gemm_batch_dispatch_table(void)
417420
{
418-
dpctl_td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t,
419-
GemmBatchContigFactory,
420-
dpctl_td_ns::num_types>
421-
contig;
422-
contig.populate_dispatch_table(gemm_batch_dispatch_table);
421+
init_dispatch_table<gemm_batch_impl_fn_ptr_t, GemmBatchContigFactory>(
422+
gemm_batch_dispatch_table);
423423
}
424424
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/output_validation.hpp"
@@ -35,14 +38,14 @@
3538
#include "gemv.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::blas
4142
{
4243
namespace mkl_blas = oneapi::mkl::blas;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_vector;
48+
4649
typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &,
4750
oneapi::mkl::transpose,
4851
const std::int64_t,
@@ -320,9 +323,7 @@ struct GemvContigFactory
320323

321324
void init_gemv_dispatch_vector(void)
322325
{
323-
dpctl_td_ns::DispatchVectorBuilder<gemv_impl_fn_ptr_t, GemvContigFactory,
324-
dpctl_td_ns::num_types>
325-
contig;
326-
contig.populate_dispatch_vector(gemv_dispatch_vector);
326+
init_dispatch_vector<gemv_impl_fn_ptr_t, GemvContigFactory>(
327+
gemv_dispatch_vector);
327328
}
328329
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/blas/syrk.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <pybind11/pybind11.h>
3030

31+
// utils extension header
3132
#include "ext/common.hpp"
3233

3334
// dpctl tensor headers
@@ -38,8 +39,6 @@
3839
#include "syrk.hpp"
3940
#include "types_matrix.hpp"
4041

41-
#include "dpnp_utils.hpp"
42-
4342
using ext::common::Align;
4443

4544
namespace dpnp::extensions::blas
@@ -48,6 +47,8 @@ namespace mkl_blas = oneapi::mkl::blas;
4847
namespace py = pybind11;
4948
namespace type_utils = dpctl::tensor::type_utils;
5049

50+
using ext::common::init_dispatch_vector;
51+
5152
typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &,
5253
const oneapi::mkl::transpose,
5354
const std::int64_t,
@@ -349,9 +350,7 @@ struct SyrkContigFactory
349350

350351
void init_syrk_dispatch_vector(void)
351352
{
352-
dpctl_td_ns::DispatchVectorBuilder<syrk_impl_fn_ptr_t, SyrkContigFactory,
353-
dpctl_td_ns::num_types>
354-
contig;
355-
contig.populate_dispatch_vector(syrk_dispatch_vector);
353+
init_dispatch_vector<syrk_impl_fn_ptr_t, SyrkContigFactory>(
354+
syrk_dispatch_vector);
356355
}
357356
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/common/ext/common.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,13 @@
3030
#include <pybind11/pybind11.h>
3131
#include <sycl/sycl.hpp>
3232

33+
// dpctl tensor headers
3334
#include "utils/math_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
3436
#include "utils/type_utils.hpp"
3537

3638
namespace type_utils = dpctl::tensor::type_utils;
39+
namespace type_dispatch = dpctl::tensor::type_dispatch;
3740

3841
namespace ext::common
3942
{
@@ -206,6 +209,25 @@ sycl::nd_range<1>
206209
// headers of dpctl.
207210
pybind11::dtype dtype_from_typenum(int dst_typenum);
208211

212+
template <typename dispatchT,
213+
template <typename fnT, typename T>
214+
typename factoryT,
215+
int _num_types = type_dispatch::num_types>
216+
inline void init_dispatch_vector(dispatchT dispatch_vector[])
217+
{
218+
type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
219+
dvb.populate_dispatch_vector(dispatch_vector);
220+
}
221+
222+
template <typename dispatchT,
223+
template <typename fnT, typename D, typename S>
224+
typename factoryT,
225+
int _num_types = type_dispatch::num_types>
226+
inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
227+
{
228+
type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
229+
dtb.populate_dispatch_table(dispatch_table);
230+
}
209231
} // namespace ext::common
210232

211233
#include "ext/details/common_internal.hpp"

dpnp/backend/extensions/indexing/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ endif()
5757

5858
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
5959

60-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
61-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
60+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
61+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
6262

6363
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
6464
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

0 commit comments

Comments
 (0)