Skip to content

Commit cfd6116

Browse files
authored
Use desc in linalg (#762)
1 parent 276fafb commit cfd6116

File tree

3 files changed

+66
-69
lines changed

3 files changed

+66
-69
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ cpdef dparray dpnp_not_equal(object input1, object input2)
271271
"""
272272
Linear algebra
273273
"""
274-
cpdef dparray dpnp_dot(dparray in_array1, dparray in_array2)
275-
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2)
274+
cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2)
275+
cpdef dparray dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dparray out=*)
276276

277277

278278
"""

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ __all__ += [
4747
ctypedef void(*fptr_2in_1out_shapes_t)(void *, void * , void * , size_t * , size_t * , size_t * , size_t)
4848

4949

50-
cpdef dparray dpnp_dot(dparray in_array1, dparray in_array2):
50+
cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
5151

5252
cdef dparray_shape_type shape1, shape2
5353

@@ -98,7 +98,7 @@ cpdef dparray dpnp_dot(dparray in_array1, dparray in_array2):
9898
return result
9999

100100

101-
cpdef dparray dpnp_inner(dparray array1, dparray array2):
101+
cpdef dparray dpnp_inner(dpnp_descriptor array1, dpnp_descriptor array2):
102102
result_type = numpy.promote_types(array1.dtype, array1.dtype)
103103

104104
assert(len(array1.shape) == len(array2.shape))
@@ -158,7 +158,7 @@ cpdef dparray dpnp_inner(dparray array1, dparray array2):
158158
return result
159159

160160

161-
cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
161+
cpdef dparray dpnp_kron(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
162162
cdef size_t ndim = max(in_array1.ndim, in_array2.ndim)
163163

164164
cdef dparray_shape_type in_array1_shape
@@ -197,12 +197,12 @@ cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
197197
return result
198198

199199

200-
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out=None):
200+
cpdef dparray dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dparray out=None):
201201

202-
cdef vector[Py_ssize_t] shape_result
202+
cdef dparray_shape_type shape_result
203203

204-
cdef vector[Py_ssize_t] shape1 = in_array1.shape
205-
cdef vector[Py_ssize_t] shape2 = in_array2.shape
204+
cdef dparray_shape_type shape1 = in_array1.shape
205+
cdef dparray_shape_type shape2 = in_array2.shape
206206

207207
cdef size_t size_m = 0
208208
cdef size_t size_n = 0
@@ -281,7 +281,7 @@ cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out=None
281281
return result
282282

283283

284-
cpdef dparray dpnp_outer(dparray array1, dparray array2):
284+
cpdef dparray dpnp_outer(dpnp_descriptor array1, dpnp_descriptor array2):
285285
cdef dparray_shape_type result_shape = (array1.size, array2.size)
286286
result_type = numpy.promote_types(array1.dtype, array1.dtype)
287287

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 56 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,13 @@
4040
"""
4141

4242

43-
import numpy
44-
4543
from dpnp.dpnp_algo import *
46-
from dpnp.dparray import dparray
4744
from dpnp.dpnp_utils import *
4845
import dpnp
4946
import dpnp.config as config
5047

48+
import numpy
49+
5150

5251
__all__ = [
5352
"dot",
@@ -92,15 +91,14 @@ def dot(x1, x2, **kwargs):
9291
9392
"""
9493

95-
is_x1_dparray = isinstance(x1, dparray)
96-
is_x2_dparray = isinstance(x2, dparray)
97-
98-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs):
99-
dim1 = x1.ndim
100-
dim2 = x2.ndim
94+
x1_desc = dpnp.get_dpnp_descriptor(x1)
95+
x2_desc = dpnp.get_dpnp_descriptor(x2)
96+
if x1_desc and x2_desc and not kwargs:
97+
dim1 = x1_desc.ndim
98+
dim2 = x2_desc.ndim
10199

102-
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1.dtype == x2.dtype):
103-
result = dpnp_dot(x1, x2)
100+
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1_desc.dtype == x2_desc.dtype):
101+
result = dpnp_dot(x1_desc, x2_desc)
104102

105103
# scalar returned
106104
if result.shape == (1,):
@@ -186,16 +184,15 @@ def inner(x1, x2, **kwargs):
186184
187185
"""
188186

189-
is_x1_dparray = isinstance(x1, dparray)
190-
is_x2_dparray = isinstance(x2, dparray)
191-
192-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs):
193-
return dpnp_inner(x1, x2)
187+
x1_desc = dpnp.get_dpnp_descriptor(x1)
188+
x2_desc = dpnp.get_dpnp_descriptor(x2)
189+
if 0 and x1_desc and x2_desc and not kwargs:
190+
return dpnp_inner(x1_desc, x2_desc)
194191

195192
return call_origin(numpy.inner, x1, x2, **kwargs)
196193

197194

198-
def kron(a, b):
195+
def kron(x1, x2):
199196
"""
200197
Returns the kronecker product of two arrays.
201198
@@ -205,23 +202,15 @@ def kron(a, b):
205202
206203
"""
207204

208-
if not use_origin_backend(a):
209-
if dpnp.isscalar(a):
210-
a = dpnp.array(a)
211-
if dpnp.isscalar(b):
212-
b = dpnp.array(b)
213-
214-
if not isinstance(a, dparray):
215-
pass
216-
elif not isinstance(b, dparray):
217-
pass
218-
else:
219-
return dpnp_kron(a, b)
205+
x1_desc = dpnp.get_dpnp_descriptor(x1)
206+
x2_desc = dpnp.get_dpnp_descriptor(x2)
207+
if x1_desc and x2_desc:
208+
return dpnp_kron(x1_desc, x2_desc)
220209

221-
return call_origin(numpy.kron, a, b)
210+
return call_origin(numpy.kron, x1, x2)
222211

223212

224-
def matmul(in_array1, in_array2, out=None, **kwargs):
213+
def matmul(x1, x2, out=None, **kwargs):
225214
"""
226215
Matrix product of two arrays.
227216
@@ -257,32 +246,42 @@ def matmul(in_array1, in_array2, out=None, **kwargs):
257246
258247
"""
259248

260-
if not use_origin_backend(in_array1) and not kwargs:
261-
if not isinstance(in_array1, dparray):
249+
x1_desc = dpnp.get_dpnp_descriptor(x1)
250+
x2_desc = dpnp.get_dpnp_descriptor(x2)
251+
out_desc = dpnp.get_dpnp_descriptor(x2)
252+
if x1_desc and x2_desc and out_desc and not kwargs:
253+
if x1_desc.size != x2_desc.size:
254+
pass
255+
elif not x1_desc.ndim:
256+
pass
257+
elif not x2_desc.ndim:
262258
pass
263-
elif not isinstance(in_array2, dparray):
259+
elif not x1_desc.size:
264260
pass
265-
elif out is not None and not isinstance(out, dparray):
261+
elif not x2_desc.size:
266262
pass
267263
else:
268-
"""
269-
Cost model checks
270-
"""
271-
272-
dparray1_size = in_array1.size
273-
dparray2_size = in_array2.size
274-
cost_size = 4096 # 2D array shape(64, 64)
275-
276-
if ((in_array1.dtype == numpy.float64) or (in_array1.dtype == numpy.float32)):
264+
if 0:
277265
"""
278-
Floating point types are handled via original math library better than SYCL math library
266+
Cost model checks
279267
"""
280-
cost_size = 262144 # 2D array shape(512, 512)
281268

282-
if (dparray1_size > cost_size) and (dparray2_size > cost_size):
283-
return dpnp_matmul(in_array1, in_array2, out=out)
269+
dparray1_size = x1_desc.size
270+
dparray2_size = x2_desc.size
271+
cost_size = 4096 # 2D array shape(64, 64)
284272

285-
return call_origin(numpy.matmul, in_array1, in_array2, out=out, **kwargs)
273+
if ((x1_desc.dtype == numpy.float64) or (x1_desc.dtype == numpy.float32)):
274+
"""
275+
Floating point types are handled via original math library better than SYCL math library
276+
"""
277+
cost_size = 262144 # 2D array shape(512, 512)
278+
279+
if (dparray1_size > cost_size) and (dparray2_size > cost_size):
280+
return dpnp_matmul(x1_desc, x2_desc, out)
281+
else:
282+
return dpnp_matmul(x1_desc, x2_desc, out)
283+
284+
return call_origin(numpy.matmul, x1, x2, out=out, **kwargs)
286285

287286

288287
def outer(x1, x2, **kwargs):
@@ -314,11 +313,10 @@ def outer(x1, x2, **kwargs):
314313
315314
"""
316315

317-
is_x1_dparray = isinstance(x1, dparray)
318-
is_x2_dparray = isinstance(x2, dparray)
319-
320-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs):
321-
return dpnp_outer(x1, x2)
316+
x1_desc = dpnp.get_dpnp_descriptor(x1)
317+
x2_desc = dpnp.get_dpnp_descriptor(x2)
318+
if 0 and x1_desc and x2_desc and not kwargs:
319+
return dpnp_outer(x1_desc, x2_desc)
322320

323321
return call_origin(numpy.outer, x1, x2, **kwargs)
324322

@@ -353,11 +351,10 @@ def tensordot(x1, x2, axes=2):
353351
354352
"""
355353

356-
is_x1_dparray = isinstance(x1, dparray)
357-
is_x2_dparray = isinstance(x2, dparray)
358-
359-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and (axes == 1)):
360-
return dpnp_tensordot(x1, x2) # dpnp_matmul
354+
x1_desc = dpnp.get_dpnp_descriptor(x1)
355+
x2_desc = dpnp.get_dpnp_descriptor(x2)
356+
if x1_desc and x2_desc and (axes == 1):
357+
return dpnp_tensordot_not_implemented(x1_desc, x2_desc) # dpnp_matmul
361358

362359
return call_origin(numpy.tensordot, x1, x2, axes)
363360

0 commit comments

Comments
 (0)