@@ -31,6 +31,27 @@ def _create_array(buf, numdims, idims, dtype, is_device):
31
31
numdims , ct .pointer (c_dims ), dtype .value ))
32
32
return out_arr
33
33
34
+ def _create_strided_array (buf , numdims , idims , dtype , is_device , offset , strides ):
35
+ out_arr = ct .c_void_p (0 )
36
+ c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
37
+ if offset is None :
38
+ offset = 0
39
+ offset = ct .c_ulonglong (offset )
40
+ if strides is None :
41
+ strides = (1 , idims [0 ], idims [0 ]* idims [1 ], idims [0 ]* idims [1 ]* idims [2 ])
42
+ while len (strides ) < 4 :
43
+ strides = strides + (strides [- 1 ],)
44
+ strides = dim4 (strides [0 ], strides [1 ], strides [2 ], strides [3 ])
45
+ if is_device :
46
+ location = Source .device
47
+ else :
48
+ location = Source .host
49
+ safe_call (backend .get ().af_create_strided_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
50
+ offset , numdims , ct .pointer (c_dims ),
51
+ ct .pointer (strides ), dtype .value ,
52
+ location .value ))
53
+ return out_arr
54
+
34
55
def _create_empty_array (numdims , idims , dtype ):
35
56
out_arr = ct .c_void_p (0 )
36
57
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -352,7 +373,7 @@ class Array(BaseArray):
352
373
353
374
"""
354
375
355
- def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False ):
376
+ def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False , offset = None , strides = None ):
356
377
357
378
super (Array , self ).__init__ ()
358
379
@@ -409,8 +430,10 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
409
430
if (type_char is not None and
410
431
type_char != _type_char ):
411
432
raise TypeError ("Can not create array of requested type from input data type" )
412
-
413
- self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
433
+ if (offset is None and strides is None ):
434
+ self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
435
+ else :
436
+ self .arr = _create_strided_array (buf , numdims , idims , to_dtype [_type_char ], is_device , offset , strides )
414
437
415
438
else :
416
439
0 commit comments