Skip to content

Commit 4a498c1

Browse files
authored
add backend for full_like, ones_like, zeros_like funcs (#638)
* add backend for full_like, ones_like, zeros_like funcs
1 parent 85aad6b commit 4a498c1

File tree

6 files changed

+115
-4
lines changed

6 files changed

+115
-4
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ INP_DLLEXPORT void dpnp_arange_c(size_t start, size_t step, void* result1, size_
141141
template <typename _DataType>
142142
INP_DLLEXPORT void dpnp_full_c(void* array_in, void* result, const size_t size);
143143

144+
/**
145+
* @ingroup BACKEND_API
146+
* @brief Implementation of full_like function
147+
*
148+
* @param [in] array_in Input one-element array.
149+
* @param [out] result Output array.
150+
* @param [in] size Number of elements in the output array.
151+
*/
152+
template <typename _DataType>
153+
INP_DLLEXPORT void dpnp_full_like_c(void* array_in, void* result, size_t size);
154+
144155
/**
145156
* @ingroup BACKEND_API
146157
* @brief Matrix multiplication.
@@ -711,6 +722,16 @@ INP_DLLEXPORT void dpnp_modf_c(void* array1_in, void* result1_out, void* result2
711722
template <typename _DataType>
712723
INP_DLLEXPORT void dpnp_ones_c(void* result, size_t size);
713724

725+
/**
726+
* @ingroup BACKEND_API
727+
* @brief Implementation of ones_like function
728+
*
729+
* @param [out] result Output array.
730+
* @param [in] size Number of elements in the output array.
731+
*/
732+
template <typename _DataType>
733+
INP_DLLEXPORT void dpnp_ones_like_c(void* result, size_t size);
734+
714735
/**
715736
* @ingroup BACKEND_API
716737
* @brief remainder function.
@@ -781,4 +802,14 @@ INP_DLLEXPORT void dpnp_trapz_c(
781802
template <typename _DataType>
782803
INP_DLLEXPORT void dpnp_zeros_c(void* result, size_t size);
783804

805+
/**
806+
* @ingroup BACKEND_API
807+
* @brief Implementation of zeros_like function
808+
*
809+
* @param [out] result Output array.
810+
* @param [in] size Number of elements in the output array.
811+
*/
812+
template <typename _DataType>
813+
INP_DLLEXPORT void dpnp_zeros_like_c(void* result, size_t size);
814+
784815
#endif // BACKEND_IFACE_H

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ enum class DPNPFuncName : size_t
109109
DPNP_FN_FLOOR_DIVIDE, /**< Used in numpy.floor_divide() implementation */
110110
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */
111111
DPNP_FN_FULL, /**< Used in numpy.full() implementation */
112+
DPNP_FN_FULL_LIKE, /**< Used in numpy.full_like() implementation */
112113
DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */
113114
DPNP_FN_INITVAL, /**< Used in numpy ones, ones_like, zeros, zeros_like, full, full_like impl */
114115
DPNP_FN_INV, /**< Used in numpy.linalg.inv() implementation */
@@ -130,6 +131,7 @@ enum class DPNPFuncName : size_t
130131
DPNP_FN_MODF, /**< Used in numpy.modf() implementation */
131132
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() implementation */
132133
DPNP_FN_ONES, /**< Used in numpy.ones() implementation */
134+
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() implementation */
133135
DPNP_FN_PLACE, /**< Used in numpy.place() implementation */
134136
DPNP_FN_POWER, /**< Used in numpy.power() implementation */
135137
DPNP_FN_PROD, /**< Used in numpy.prod() implementation */
@@ -195,6 +197,7 @@ enum class DPNPFuncName : size_t
195197
DPNP_FN_TRUNC, /**< Used in numpy.trunc() implementation */
196198
DPNP_FN_VAR, /**< Used in numpy.var() implementation */
197199
DPNP_FN_ZEROS, /**< Used in numpy.zeros() implementation */
200+
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() implementation */
198201
DPNP_FN_LAST /**< The latest element of the enumeration */
199202
};
200203

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ void dpnp_full_c(void* array_in, void* result, const size_t size)
104104
dpnp_initval_c<_DataType>(result, array_in, size);
105105
}
106106

107+
template <typename _DataType>
108+
void dpnp_full_like_c(void* array_in, void* result, const size_t size)
109+
{
110+
dpnp_full_c<_DataType>(array_in, result, size);
111+
}
112+
107113
template <typename _DataType>
108114
void dpnp_ones_c(void* result, size_t size)
109115
{
@@ -115,6 +121,12 @@ void dpnp_ones_c(void* result, size_t size)
115121
dpnp_memory_free_c(fill_value);
116122
}
117123

124+
template <typename _DataType>
125+
void dpnp_ones_like_c(void* result, size_t size)
126+
{
127+
dpnp_ones_c<_DataType>(result, size);
128+
}
129+
118130
template <typename _DataType>
119131
void dpnp_zeros_c(void* result, size_t size)
120132
{
@@ -126,6 +138,12 @@ void dpnp_zeros_c(void* result, size_t size)
126138
dpnp_memory_free_c(fill_value);
127139
}
128140

141+
template <typename _DataType>
142+
void dpnp_zeros_like_c(void* result, size_t size)
143+
{
144+
dpnp_zeros_c<_DataType>(result, size);
145+
}
146+
129147
template <typename _DataType>
130148
void dpnp_tril_c(void* array_in,
131149
void* result1,
@@ -345,6 +363,15 @@ void func_map_init_arraycreation(func_map_t& fmap)
345363
fmap[DPNPFuncName::DPNP_FN_FULL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_full_c<long>};
346364
fmap[DPNPFuncName::DPNP_FN_FULL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_full_c<float>};
347365
fmap[DPNPFuncName::DPNP_FN_FULL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_full_c<double>};
366+
fmap[DPNPFuncName::DPNP_FN_FULL][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_full_c<bool>};
367+
fmap[DPNPFuncName::DPNP_FN_FULL][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_full_c<std::complex<double>>};
368+
369+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_full_like_c<int>};
370+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_full_like_c<long>};
371+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_full_like_c<float>};
372+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_full_like_c<double>};
373+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_full_like_c<bool>};
374+
fmap[DPNPFuncName::DPNP_FN_FULL_LIKE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_full_like_c<std::complex<double>>};
348375

349376
fmap[DPNPFuncName::DPNP_FN_ONES][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_ones_c<int>};
350377
fmap[DPNPFuncName::DPNP_FN_ONES][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_ones_c<long>};
@@ -353,13 +380,27 @@ void func_map_init_arraycreation(func_map_t& fmap)
353380
fmap[DPNPFuncName::DPNP_FN_ONES][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_ones_c<bool>};
354381
fmap[DPNPFuncName::DPNP_FN_ONES][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_ones_c<std::complex<double>>};
355382

383+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_ones_like_c<int>};
384+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_ones_like_c<long>};
385+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_ones_like_c<float>};
386+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_ones_like_c<double>};
387+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_ones_like_c<bool>};
388+
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_ones_like_c<std::complex<double>>};
389+
356390
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_zeros_c<int>};
357391
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_zeros_c<long>};
358392
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_zeros_c<float>};
359393
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_zeros_c<double>};
360394
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_zeros_c<bool>};
361395
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_ones_c<std::complex<double>>};
362396

397+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_zeros_like_c<int>};
398+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_zeros_like_c<long>};
399+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_zeros_like_c<float>};
400+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_zeros_like_c<double>};
401+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_zeros_like_c<bool>};
402+
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_ones_like_c<std::complex<double>>};
403+
363404
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tril_c<int>};
364405
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tril_c<long>};
365406
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
8282
DPNP_FN_FLOOR_DIVIDE
8383
DPNP_FN_FMOD
8484
DPNP_FN_FULL
85+
DPNP_FN_FULL_LIKE
8586
DPNP_FN_HYPOT
8687
DPNP_FN_INITVAL
8788
DPNP_FN_INV
@@ -103,6 +104,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
103104
DPNP_FN_MODF
104105
DPNP_FN_MULTIPLY
105106
DPNP_FN_ONES
107+
DPNP_FN_ONES_LIKE
106108
DPNP_FN_PLACE
107109
DPNP_FN_POWER
108110
DPNP_FN_PROD
@@ -168,6 +170,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
168170
DPNP_FN_TRUNC
169171
DPNP_FN_VAR
170172
DPNP_FN_ZEROS
173+
DPNP_FN_ZEROS_LIKE
171174

172175
cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncType": # need this namespace for Enum import
173176
cdef enum DPNPFuncType "DPNPFuncType":

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,18 @@ __all__ += [
4444
"dpnp_copy",
4545
"dpnp_diag",
4646
"dpnp_full",
47+
"dpnp_full_like",
4748
"dpnp_geomspace",
4849
"dpnp_linspace",
4950
"dpnp_logspace",
5051
"dpnp_meshgrid",
5152
"dpnp_ones",
53+
"dpnp_ones_like",
5254
"dpnp_tri",
5355
"dpnp_tril",
5456
"dpnp_triu",
55-
"dpnp_zeros"
57+
"dpnp_zeros",
58+
"dpnp_zeros_like"
5659
]
5760

5861

@@ -112,6 +115,28 @@ cpdef dparray dpnp_full(result_shape, value_in, result_dtype):
112115
return result
113116

114117

118+
cpdef dparray dpnp_full_like(result_shape, value_in, result_dtype):
119+
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
120+
cdef DPNPFuncType dtype_in = dpnp_dtype_to_DPNPFuncType(result_dtype)
121+
122+
# get the FPTR data structure
123+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_FULL_LIKE, dtype_in, DPNP_FT_NONE)
124+
125+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
126+
# Create single-element input array with type given by FPTR data
127+
cdef dparray_shape_type shape_in = (1,)
128+
cdef dparray array_in = dparray(shape_in, dtype=result_type)
129+
array_in[0] = value_in
130+
# Create result array with type given by FPTR data
131+
cdef dparray result = dparray(result_shape, dtype=result_type)
132+
133+
cdef fptr_1in_1out_t func = <fptr_1in_1out_t > kernel_data.ptr
134+
# Call FPTR function
135+
func(array_in.get_data(), result.get_data(), result.size)
136+
137+
return result
138+
139+
115140
cpdef dparray dpnp_geomspace(start, stop, num, endpoint, dtype, axis):
116141
cdef dparray result = dparray(num, dtype=dtype)
117142

@@ -222,6 +247,10 @@ cpdef dparray dpnp_ones(result_shape, result_dtype):
222247
return call_fptr_1out(DPNP_FN_ONES, result_shape, result_dtype)
223248

224249

250+
cpdef dparray dpnp_ones_like(result_shape, result_dtype):
251+
return call_fptr_1out(DPNP_FN_ONES_LIKE, result_shape, result_dtype)
252+
253+
225254
cpdef dparray dpnp_tri(N, M, k, dtype):
226255
cdef dparray result
227256

@@ -284,3 +313,7 @@ cpdef dparray dpnp_triu(dparray m, int k):
284313

285314
cpdef dparray dpnp_zeros(result_shape, result_dtype):
286315
return call_fptr_1out(DPNP_FN_ZEROS, result_shape, result_dtype)
316+
317+
318+
cpdef dparray dpnp_zeros_like(result_shape, result_dtype):
319+
return call_fptr_1out(DPNP_FN_ZEROS_LIKE, result_shape, result_dtype)

dpnp/dpnp_iface_arraycreation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def full_like(x1, fill_value, dtype=None, order='C', subok=False, shape=None):
665665
_shape = shape if shape is not None else x1.shape
666666
_dtype = dtype if dtype is not None else x1.dtype
667667

668-
return dpnp_full(_shape, fill_value, _dtype)
668+
return dpnp_full_like(_shape, fill_value, _dtype)
669669

670670
return numpy.full_like(x1, fill_value, dtype, order, subok, shape)
671671

@@ -1031,7 +1031,7 @@ def ones_like(x1, dtype=None, order='C', subok=False, shape=None):
10311031
_shape = shape if shape is not None else x1.shape
10321032
_dtype = dtype if dtype is not None else x1.dtype
10331033

1034-
return dpnp_ones(_shape, _dtype)
1034+
return dpnp_ones_like(_shape, _dtype)
10351035

10361036
return numpy.ones_like(x1, dtype, order, subok, shape)
10371037

@@ -1202,6 +1202,6 @@ def zeros_like(x1, dtype=None, order='C', subok=False, shape=None):
12021202
_shape = shape if shape is not None else x1.shape
12031203
_dtype = dtype if dtype is not None else x1.dtype
12041204

1205-
return dpnp_zeros(_shape, _dtype)
1205+
return dpnp_zeros_like(_shape, _dtype)
12061206

12071207
return numpy.zeros_like(x1, dtype, order, subok, shape)

0 commit comments

Comments
 (0)