Skip to content

Commit ad0da18

Browse files
committed
Update documentation on Data module.
1 parent d03b1aa commit ad0da18

File tree

4 files changed

+137
-29
lines changed

4 files changed

+137
-29
lines changed

shell.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ let
44
in
55
pkgs.lib.overrideDerivation pkg (drv: {
66
shellHook = ''
7+
export AF_PRINT_ERRORS=1
78
export PATH=$PATH:${pkgs.haskellPackages.doctest}/bin
89
export PATH=$PATH:${pkgs.haskellPackages.cabal-install}/bin
910
function ghcid () {

src/ArrayFire/Data.hs

Lines changed: 133 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ constant dims val =
178178
n = fromIntegral (length dims')
179179

180180
-- | Creates a range of values in an Array
181+
--
181182
-- >>> range @Double [10] (-1)
182183
-- ArrayFire Array
183184
-- [10 1 1 1]
@@ -187,8 +188,8 @@ range
187188
. AFType a
188189
=> [Int]
189190
-> Int
190-
-> IO (Array a)
191-
range dims (fromIntegral -> k) = do
191+
-> Array a
192+
range dims (fromIntegral -> k) = unsafePerformIO $ do
192193
ptr <- alloca $ \ptrPtr -> mask_ $ do
193194
withArray (fromIntegral <$> dims) $ \dimArray -> do
194195
throwAFError =<< af_range ptrPtr n dimArray k typ
@@ -201,28 +202,53 @@ range dims (fromIntegral -> k) = do
201202
n = fromIntegral (length dims)
202203
typ = afType (Proxy @ a)
203204

205+
-- | Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
206+
--
207+
-- <http://arrayfire.org/docs/group__data__func__iota.htm>
208+
--
209+
-- >>> iota @Double [5,3] []
210+
-- ArrayFire Array
211+
-- [5 3 1 1]
212+
-- 0.0000 1.0000 2.0000 3.0000 4.0000
213+
-- 5.0000 6.0000 7.0000 8.0000 9.0000
214+
-- 10.0000 11.0000 12.0000 13.0000 14.0000
215+
--
216+
-- >>> iota @Double [5,3] [1,2]
217+
-- ArrayFire Array
218+
-- [5 6 1 1]
219+
-- 0.0000 1.0000 2.0000 3.0000 4.0000
220+
-- 5.0000 6.0000 7.0000 8.0000 9.0000
221+
-- 10.0000 11.0000 12.0000 13.0000 14.0000
222+
-- 0.0000 1.0000 2.0000 3.0000 4.0000
223+
-- 5.0000 6.0000 7.0000 8.0000 9.0000
224+
-- 10.0000 11.0000 12.0000 13.0000 14.0000
204225
iota
205226
:: forall a . AFType a
206-
=> [Int] -> [Int] -> IO (Array a)
207-
iota dims tdims = do
227+
=> [Int]
228+
-- ^ is the array containing sizes of the dimension
229+
-> [Int]
230+
-- ^ is array containing the number of repetitions of the unit dimensions
231+
-> Array a
232+
-- ^ is the generated array
233+
iota dims tdims = unsafePerformIO $ do
234+
let dims' = take 4 (dims ++ repeat 1)
235+
tdims' = take 4 (tdims ++ repeat 1)
208236
ptr <- alloca $ \ptrPtr -> mask_ $ do
209237
zeroOutArray ptrPtr
210-
withArray (fromIntegral <$> dims) $ \dimArray ->
211-
withArray (fromIntegral <$> tdims) $ \tdimArray -> do
212-
throwAFError =<< af_iota ptrPtr n dimArray tn tdimArray typ
238+
withArray (fromIntegral <$> dims') $ \dimArray ->
239+
withArray (fromIntegral <$> tdims') $ \tdimArray -> do
240+
throwAFError =<< af_iota ptrPtr 4 dimArray 4 tdimArray typ
213241
peek ptrPtr
214242
Array <$>
215243
newForeignPtr
216244
af_release_array_finalizer
217245
ptr
218246
where
219-
n = fromIntegral (length dims)
220-
tn = fromIntegral (length tdims)
221247
typ = afType (Proxy @ a)
222248

223249
-- | Creates the identity `Array` from given dimensions
224250
--
225-
-- >>> identity [2,2] 2.0
251+
-- >>> identity [2,2]
226252
-- ArrayFire Array
227253
-- [2 2 1 1]
228254
-- 1.0000 0.0000
@@ -233,9 +259,10 @@ identity
233259
-- ^ Dimensions
234260
-> Array a
235261
identity dims = unsafePerformIO . mask_ $ do
262+
let dims' = take 4 (dims ++ repeat 1)
236263
ptr <- alloca $ \ptrPtr -> mask_ $ do
237264
zeroOutArray ptrPtr
238-
withArray (fromIntegral <$> dims) $ \dimArray -> do
265+
withArray (fromIntegral <$> dims') $ \dimArray -> do
239266
throwAFError =<< af_identity ptrPtr n dimArray typ
240267
peek ptrPtr
241268
Array <$>
@@ -246,14 +273,29 @@ identity dims = unsafePerformIO . mask_ $ do
246273
n = fromIntegral (length dims)
247274
typ = afType (Proxy @ a)
248275

276+
-- | Create a diagonal matrix from input array when extract is set to false
277+
--
278+
-- >>> diagCreate (vector @Double 2 [1..]) 0
279+
-- ArrayFire Array
280+
-- [2 2 1 1]
281+
-- 1.0000 0.0000
282+
-- 0.0000 2.0000
249283
diagCreate
250284
:: AFType (a :: *)
251285
=> Array a
286+
-- ^ is the input array which is the diagonal
252287
-> Int
288+
-- ^ is the diagonal index
253289
-> Array a
254290
diagCreate x (fromIntegral -> n) =
255291
x `op1` (\p a -> af_diag_create p a n)
256292

293+
-- | Create a diagonal matrix from input array when extract is set to false
294+
--
295+
-- >>> diagExtract (matrix @Double (2,2) [[1,2],[3,4]]) 0
296+
-- ArrayFire Array
297+
-- [2 1 1 1]
298+
-- 1.0000 4.0000
257299
diagExtract
258300
:: AFType (a :: *)
259301
=> Array a
@@ -262,13 +304,29 @@ diagExtract
262304
diagExtract x (fromIntegral -> n) =
263305
x `op1` (\p a -> af_diag_extract p a n)
264306

307+
-- | Join two Arrays together along a specified dimension
308+
--
309+
-- >>> join 0 (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[5,6],[7,8]])
310+
-- ArrayFire Array
311+
-- [4 2 1 1]
312+
-- 1.0000 2.0000 5.0000 6.0000
313+
-- 3.0000 4.0000 7.0000 8.0000
314+
--
265315
join
266316
:: Int
267317
-> Array (a :: *)
268318
-> Array a
269319
-> Array a
270320
join (fromIntegral -> n) arr1 arr2 = op2 arr1 arr2 (\p a b -> af_join p n a b)
271321

322+
-- | Join many Arrays together along a specified dimension
323+
--
324+
-- *FIX ME*
325+
-- >>> joinMany 0 [1,2,3]
326+
-- ArrayFire Array
327+
-- [3 1 1 1]
328+
-- 1.0000 2.0000 3.0000
329+
--
272330
joinMany
273331
:: Int
274332
-> [Array a]
@@ -288,26 +346,48 @@ joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do
288346
where
289347
nArrays = fromIntegral (length arrays)
290348

349+
-- | Tiles an Array according to specified dimensions
350+
--
351+
-- >>> tile @Double (scalar 22.0) [5,5]
352+
-- ArrayFire Array
353+
-- [5 5 1 1]
354+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
355+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
356+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
357+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
358+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
291359
tile
292360
:: Array (a :: *)
293-
-> Int
294-
-> Int
295-
-> Int
296-
-> Int
361+
-> [Int]
297362
-> Array a
298-
tile a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegral -> w) =
299-
a `op1` (\p k -> af_tile p k x y z w)
363+
tile a (take 4 . (++repeat 1) -> [x,y,z,w]) =
364+
a `op1` (\p k -> af_tile p k (fromIntegral x) (fromIntegral y) (fromIntegral z) (fromIntegral w))
365+
tile _ _ = error "impossible"
300366

367+
-- | Reorders an Array according to newly specified dimensions
368+
--
369+
-- *FIX ME*
370+
-- >>> reorder @Double (scalar 22.0) [5,5]
371+
-- ArrayFire Array
372+
-- [5 5 1 1]
373+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
374+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
375+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
376+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
377+
-- 22.0000 22.0000 22.0000 22.0000 22.0000
301378
reorder
302379
:: Array (a :: *)
303-
-> Int
304-
-> Int
305-
-> Int
306-
-> Int
380+
-> [Int]
307381
-> Array a
308-
reorder a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegral -> w) =
309-
a `op1` (\p k -> af_tile p k x y z w)
382+
reorder a (take 4 . (++ repeat 0) -> [x,y,z,w]) =
383+
a `op1` (\p k -> af_reorder p k (fromIntegral x) (fromIntegral y) (fromIntegral z) (fromIntegral w))
384+
reorder _ _ = error "impossible"
310385

386+
-- | Shift elements in an Array along a specified dimension (elements will wrap).
387+
-- >>> shift (vector @Double 4 [1..]) 2 0 0 0
388+
-- ArrayFire Array
389+
-- [4 1 1 1]
390+
-- 3.0000 4.0000 1.0000 2.0000
311391
shift
312392
:: Array (a :: *)
313393
-> Int
@@ -318,12 +398,20 @@ shift
318398
shift a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegral -> w) =
319399
a `op1` (\p k -> af_shift p k x y z w)
320400

401+
-- | Modify dimensions of array
402+
--
403+
-- >>> moddims (vector @Double 3 [1..]) [1,3]
404+
-- ArrayFire Array
405+
-- [1 3 1 1]
406+
-- 1.0000
407+
-- 2.0000
408+
-- 3.0000
321409
moddims
322410
:: forall a
323-
. [Int]
324-
-> Array (a :: *)
411+
. Array (a :: *)
412+
-> [Int]
325413
-> Array a
326-
moddims dims (Array fptr) =
414+
moddims (Array fptr) dims =
327415
unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do
328416
newPtr <- alloca $ \aPtr -> do
329417
zeroOutArray aPtr
@@ -334,11 +422,30 @@ moddims dims (Array fptr) =
334422
where
335423
n = fromIntegral (length dims)
336424

425+
-- | Flatten an Array into a single dimension
426+
--
427+
-- >>> flat (matrix @Double (2,2) [[1..],[1..]])
428+
-- ArrayFire Array
429+
-- [4 1 1 1]
430+
-- 1.0000 2.0000 1.0000 2.0000
337431
flat
338432
:: Array a
339433
-> Array a
340434
flat = (`op1` af_flat)
341435

436+
-- | Flip the values of an Array along a specified dimension
437+
--
438+
-- >>> matrix @Double (2,2) [[2,2],[3,3]]
439+
-- ArrayFire Array
440+
-- [2 2 1 1]
441+
-- 2.0000 2.0000
442+
-- 3.0000 3.0000
443+
--
444+
-- >>> A.flip (matrix @Double (2,2) [[2,2],[3,3]]) 1
445+
-- ArrayFire Array
446+
-- [2 2 1 1]
447+
-- 3.0000 3.0000
448+
-- 2.0000 2.0000
342449
flip
343450
:: Array a
344451
-> Int

src/ArrayFire/Orphans.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ instance (Ord a, AFType a) => Ord (Array a) where
4545
x <= y = toEnum . fromIntegral $ A.getScalar @CBool @a (A.le x y)
4646
x >= y = toEnum . fromIntegral $ A.getScalar @CBool @a (A.ge x y)
4747

48-
instance AFType a => Semigroup (Array a) where
49-
x <> y = A.mul x y
48+
instance (Fractional a, AFType a) => Semigroup (Array a) where
49+
(<>) = A.mul
5050

5151
instance Show (Array a) where
5252
show = arrayString

src/ArrayFire/Random.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ rand dims f = mask_ $ do
275275
--
276276
randn
277277
:: forall a
278-
. AFType a
278+
. (AFType a, Fractional a)
279279
=> [Int]
280280
-- ^ Dimensions of random array
281281
-> IO (Array a)

0 commit comments

Comments
 (0)