Skip to content

Commit efa5c45

Browse files
committed
Update entire ufunc extension to use init_dispatch_vector() and using init_dispatch_table() from common utils
1 parent e4440ad commit efa5c45

File tree

4 files changed

+26
-31
lines changed

4 files changed

+26
-31
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/float_power.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
#include "float_power.hpp"
3535

36+
// utils extension header
37+
#include "ext/common.hpp"
38+
3639
// include a local copy of elementwise common header from dpctl tensor:
3740
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
3841
// TODO: replace by including dpctl header once available
@@ -50,6 +53,8 @@ namespace impl
5053
{
5154
namespace td_ns = dpctl::tensor::type_dispatch;
5255

56+
using ext::common::init_dispatch_table;
57+
5358
// Supports only float and complex types
5459
template <typename T1, typename T2>
5560
struct OutputType
@@ -82,10 +87,9 @@ struct TypeMapFactory
8287
}
8388
};
8489

85-
void populate_float_power_dispatch_tables(void)
90+
static void populate_float_power_dispatch_tables(void)
8691
{
87-
td_ns::DispatchTableBuilder<int, TypeMapFactory, td_ns::num_types> dvb;
88-
dvb.populate_dispatch_table(float_power_output_typeid_table);
92+
init_dispatch_table<int, TypeMapFactory>(float_power_output_typeid_table);
8993
}
9094
} // namespace impl
9195

dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
#include "kernels/elementwise_functions/interpolate.hpp"
4646

47+
// utils extension headers
4748
#include "ext/common.hpp"
4849
#include "ext/validation_utils.hpp"
4950

@@ -67,6 +68,7 @@ namespace dpnp::extensions::ufunc
6768

6869
namespace impl
6970
{
71+
using ext::common::init_dispatch_vector;
7072

7173
template <typename T>
7274
using value_type_of_t = typename value_type_of<T>::type;
@@ -242,13 +244,10 @@ struct InterpolateFactory
242244
}
243245
};
244246

245-
void init_interpolate_dispatch_vectors()
247+
static void init_interpolate_dispatch_vectors()
246248
{
247-
using namespace td_ns;
248-
249-
DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types>
250-
dtb_interpolate;
251-
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector);
249+
init_dispatch_vector<interpolate_fn_ptr_t, InterpolateFactory>(
250+
interpolate_dispatch_vector);
252251
}
253252

254253
} // namespace impl

dpnp/backend/extensions/ufunc/elementwise_functions/isclose.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "utils/type_dispatch.hpp"
4646
#include "utils/type_utils.hpp"
4747

48+
// utils extension headers
4849
#include "ext/common.hpp"
4950
#include "ext/validation_utils.hpp"
5051

@@ -68,6 +69,7 @@ namespace dpnp::extensions::ufunc
6869

6970
namespace impl
7071
{
72+
using ext::common::init_dispatch_vector;
7173

7274
typedef sycl::event (*isclose_strided_scalar_fn_ptr_t)(
7375
sycl::queue &,
@@ -351,19 +353,14 @@ struct IsCloseContigScalarFactory
351353
}
352354
};
353355

354-
void populate_isclose_dispatch_vectors()
356+
static void populate_isclose_dispatch_vectors()
355357
{
356-
using namespace td_ns;
357-
358-
DispatchVectorBuilder<isclose_strided_scalar_fn_ptr_t,
359-
IsCloseStridedScalarFactory, num_types>
360-
dvb1;
361-
dvb1.populate_dispatch_vector(isclose_strided_scalar_dispatch_vector);
362-
363-
DispatchVectorBuilder<isclose_contig_scalar_fn_ptr_t,
364-
IsCloseContigScalarFactory, num_types>
365-
dvb2;
366-
dvb2.populate_dispatch_vector(isclose_contig_dispatch_vector);
358+
init_dispatch_vector<isclose_strided_scalar_fn_ptr_t,
359+
IsCloseStridedScalarFactory>(
360+
isclose_strided_scalar_dispatch_vector);
361+
init_dispatch_vector<isclose_contig_scalar_fn_ptr_t,
362+
IsCloseContigScalarFactory>(
363+
isclose_contig_dispatch_vector);
367364
}
368365

369366
} // namespace impl

dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "utils/type_dispatch.hpp"
5353
#include "utils/type_utils.hpp"
5454

55+
// utils extension header
5556
#include "ext/common.hpp"
5657

5758
namespace py = pybind11;
@@ -65,6 +66,7 @@ namespace dpnp::extensions::ufunc
6566

6667
namespace impl
6768
{
69+
using ext::common::init_dispatch_vector;
6870

6971
template <typename T>
7072
using value_type_of_t = typename value_type_of<T>::type;
@@ -370,17 +372,10 @@ struct NanToNumContigFactory
370372
}
371373
};
372374

373-
void populate_nan_to_num_dispatch_vectors(void)
375+
static void populate_nan_to_num_dispatch_vectors(void)
374376
{
375-
using namespace td_ns;
376-
377-
DispatchVectorBuilder<nan_to_num_fn_ptr_t, NanToNumFactory, num_types> dvb1;
378-
dvb1.populate_dispatch_vector(nan_to_num_dispatch_vector);
379-
380-
DispatchVectorBuilder<nan_to_num_contig_fn_ptr_t, NanToNumContigFactory,
381-
num_types>
382-
dvb2;
383-
dvb2.populate_dispatch_vector(nan_to_num_contig_dispatch_vector);
377+
init_dispatch_vector<nan_to_num_fn_ptr_t, NanToNumFactory>(
378+
nan_to_num_dispatch_vector);
384379
}
385380

386381
} // namespace impl

0 commit comments

Comments
 (0)