Skip to content

Commit deef7ea

Browse files
authored
Add NumPy.ndarray as return type (#840)
1 parent 99183ae commit deef7ea

File tree

3 files changed

+17
-29
lines changed

3 files changed

+17
-29
lines changed

dpnp/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@
3737
'''
3838
Explicitly use GPU for SYCL queue
3939
'''
40+
41+
__DPNP_OUTPUT_NUMPY__ = int(os.getenv('DPNP_OUTPUT_NUMPY', 0))
42+
'''
43+
Explicitly use NumPy.ndarray as return type for creation functions
44+
'''

dpnp/dpnp_utils/dpnp_algo_utils.pxd

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,16 +159,9 @@ cdef DPNPFuncType get_output_c_type(DPNPFuncName funcID,
159159
Calculate output array type by 'out' and 'dtype' cast parameters
160160
"""
161161

162-
cdef dparray create_output_array(shape_type_c output_shape,
163-
DPNPFuncType c_type,
164-
object requested_out)
165-
"""
166-
Create output array based on shape, type and 'out' parameters
167-
"""
168-
169162
cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape,
170163
DPNPFuncType c_type,
171164
dpnp_descriptor requested_out)
172165
"""
173-
Same as "create_output_array" but output is "dpnp_descriptor"
166+
Create output dpnp_descriptor based on shape, type and 'out' parameters
174167
"""

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -349,37 +349,27 @@ cdef DPNPFuncType get_output_c_type(DPNPFuncName funcID,
349349
checker_throw_value_error("get_output_c_type", "dtype and out", requested_dtype, requested_out)
350350

351351

352-
cdef dparray create_output_array(shape_type_c output_shape, DPNPFuncType c_type, object requested_out):
353-
"""
354-
TODO This function needs to be deleted. Replace with create_output_descriptor()
355-
"""
356-
357-
cdef dparray result
358-
359-
if requested_out is None:
360-
""" Create DPNP array """
361-
result = dparray(output_shape, dtype=dpnp_DPNPFuncType_to_dtype( < size_t > c_type))
362-
else:
363-
""" Based on 'out' parameter """
364-
if (output_shape != requested_out.shape):
365-
checker_throw_value_error("create_output_array", "out.shape", requested_out.shape, output_shape)
366-
result = requested_out
367-
368-
return result
369-
370352
cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape,
371353
DPNPFuncType c_type,
372354
dpnp_descriptor requested_out):
373355
cdef dpnp_descriptor result_desc
374356

375357
if requested_out is None:
376-
""" Create DPNP array """
377-
result = dparray(output_shape, dtype=dpnp_DPNPFuncType_to_dtype( < size_t > c_type))
358+
result = None
359+
result_dtype = dpnp_DPNPFuncType_to_dtype( < size_t > c_type)
360+
if config.__DPNP_OUTPUT_NUMPY__:
361+
""" Create NumPy ndarray """
362+
# TODO need to use "buffer=" parameter to use SYCL aware memory
363+
result = numpy.ndarray(output_shape, dtype=result_dtype)
364+
else:
365+
""" Create DPNP array """
366+
result = dparray(output_shape, dtype=result_dtype)
367+
378368
result_desc = dpnp_descriptor(result)
379369
else:
380370
""" Based on 'out' parameter """
381371
if (output_shape != requested_out.shape):
382-
checker_throw_value_error("create_output_array", "out.shape", requested_out.shape, output_shape)
372+
checker_throw_value_error("create_output_descriptor", "out.shape", requested_out.shape, output_shape)
383373

384374
if isinstance(requested_out, dpnp_descriptor):
385375
result_desc = requested_out

0 commit comments

Comments
 (0)