Skip to content

Commit cdd0ce0

Browse files
committed
Add INLINE or INLINABLE pragmas to all storable ranked code
1 parent 502d4d2 commit cdd0ce0

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

Data/Array/Internal.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ traverseT
389389
traverseT sh f a = fmap (fromListT sh) (traverse f (toListT sh a))
390390

391391
-- Fast check if all elements are equal.
392+
{-# INLINABLE allSameT #-}
392393
allSameT :: (Vector v, VecElem v a, Eq a) => ShapeL -> T v a -> Bool
393394
allSameT sh t@(T _ _ v)
394395
| vLength v <= 1 = True
@@ -482,6 +483,7 @@ zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
482483
zipWithLong2 f (a:as) (b:bs) = f a b : zipWithLong2 f as bs
483484
zipWithLong2 _ _ bs = bs
484485

486+
{-# INLINABLE padT #-}
485487
padT :: forall v a . (Vector v, VecElem v a) => a -> [(Int, Int)] -> ShapeL -> T v a -> ([Int], T v a)
486488
padT v aps ash at = (ss, fromVectorT ss $ vConcat $ pad' aps ash st at)
487489
where pad' :: [(Int, Int)] -> ShapeL -> [Int] -> T v a -> [v a]
@@ -563,13 +565,15 @@ iotaT n = fromListT [n] [0 .. fromIntegral n - 1] -- TODO: should use V.enumF
563565
-------
564566

565567
-- | Permute the elements of a list, the first argument is indices into the original list.
568+
{-# INLINE permute #-}
566569
permute :: [Int] -> [a] -> [a]
567570
permute is xs = map (xs!!) is
568571

569572
-- | Like 'dropWhile' but at the end of the list.
570573
revDropWhile :: (a -> Bool) -> [a] -> [a]
571574
revDropWhile p = reverse . dropWhile p . reverse
572575

576+
{-# INLINEABLE allSame #-}
573577
allSame :: (Eq a) => [a] -> Bool
574578
allSame [] = True
575579
allSame (x : xs) = all (x ==) xs

Data/Array/Internal/RankedG.hs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ data Array (n :: Nat) v a = A !ShapeL !(T v a)
6969
deriving (Generic, Data)
7070

7171
instance (Vector v, Show a, VecElem v a) => Show (Array n v a) where
72-
{-# INLINABLE showsPrec #-}
7372
showsPrec p a@(A s _) = showParen (p > 10) $
7473
showString "fromList " . showsPrec 11 s . showString " " . showsPrec 11 (toList a)
74+
{-# INLINABLE showsPrec #-}
7575

7676
instance (KnownNat n, Vector v, Read a, VecElem v a) => Read (Array n v a) where
7777
readsPrec p = readParen (p > 10) $ \ r1 ->
@@ -92,6 +92,7 @@ instance (Vector v, Pretty a, VecElem v a) => Pretty (Array n v a) where
9292

9393
instance (NFData (v a)) => NFData (Array n v a) where
9494
rnf (A sh v) = rnf sh `seq` rnf v
95+
{-# INLINE rnf #-}
9596

9697
-- | The number of elements in the array.
9798
-- O(1) time.
@@ -218,7 +219,7 @@ constant :: forall n v a . (Vector v, VecElem v a, KnownNat n) =>
218219
ShapeL -> a -> Array n v a
219220
constant sh | badShape sh = error $ "constant: bad shape: " ++ show sh
220221
| length sh /= valueOf @n = error "constant: rank mismatch"
221-
| otherwise = A sh . constantT sh
222+
| otherwise = A sh . constantT sh
222223

223224
-- | Map over the array elements.
224225
-- O(n) time.
@@ -344,6 +345,7 @@ stride ats (A ash (T ss o v)) = A (str ats ash) (T (zipWith (*) (ats ++ repeat 1
344345
-- | Rotate the array k times along the d'th dimension.
345346
-- E.g., if the array shape is @[2, 3, 2]@, d is 1, and k is 4,
346347
-- the resulting shape will be @[2, 4, 3, 2]@.
348+
{-# INLINABLE rotate #-}
347349
rotate :: forall d p v a.
348350
(KnownNat p, KnownNat d,
349351
Vector v, VecElem v a,
@@ -455,6 +457,7 @@ traverseA
455457
traverseA f (A sh t) = A sh <$> traverseT sh f t
456458

457459
-- | Check if all elements of the array are equal.
460+
{-# INLINE allSameA #-}
458461
allSameA :: (Vector v, VecElem v a, Eq a) => Array r v a -> Bool
459462
allSameA (A sh t) = allSameT sh t
460463

@@ -500,6 +503,7 @@ allA p (A sh t) = allT sh p t
500503
-- and just replicate the data along all other dimensions.
501504
-- The list of dimensions indicies must have the same rank as the argument array
502505
-- and it must be strictly ascending.
506+
{-# INLINABLE broadcast #-}
503507
broadcast :: forall r' r v a .
504508
(HasCallStack, Vector v, VecElem v a, KnownNat r, KnownNat r') =>
505509
[Int] -> ShapeL -> Array r v a -> Array r' v a

Data/Array/Internal/RankedS.hs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,20 @@ size = product . shapeL
9494
-- In the linearization of the array the outermost (i.e. first list element)
9595
-- varies most slowly.
9696
-- O(1) time.
97+
{-# INLINE shapeL #-}
9798
shapeL :: Array n a -> ShapeL
9899
shapeL = G.shapeL . unA
99100

100101
-- | The rank of an array, i.e., the number of dimensions it has,
101102
-- which is the @n@ in @Array n a@.
102103
-- O(1) time.
104+
{-# INLINE rank #-}
103105
rank :: (KnownNat n) => Array n a -> Int
104106
rank = G.rank . unA
105107

106108
-- | Index into an array. Fails if the index is out of bounds.
107109
-- O(1) time.
110+
{-# INLINABLE index #-}
108111
index :: (Unbox a) => Array (1+n) a -> Int -> Array n a
109112
index a = A . G.index (unA a)
110113

@@ -138,6 +141,7 @@ fromVector ss = A . G.fromVector ss
138141
-- This is semantically an identity function, but can have big performance
139142
-- implications.
140143
-- O(n) or O(1) time.
144+
{-# INLINABLE normalize #-}
141145
normalize :: (Unbox a, KnownNat n) => Array n a -> Array n a
142146
normalize = A . G.normalize . unA
143147

@@ -150,20 +154,24 @@ reshape s = A . G.reshape s . unA
150154
-- | Change the size of dimensions with size 1. These dimension can be changed to any size.
151155
-- All other dimensions must remain the same.
152156
-- O(1) time.
157+
{-# INLINABLE stretch #-}
153158
stretch :: ShapeL -> Array n a -> Array n a
154159
stretch s = A . G.stretch s . unA
155160

156161
-- | Change the size of the outermost dimension by replication.
162+
{-# INLINABLE stretchOuter #-}
157163
stretchOuter :: (HasCallStack, 1 <= n) => Int -> Array n a -> Array n a
158164
stretchOuter s = A . G.stretchOuter s . unA
159165

160166
-- | Convert a value to a scalar (rank 0) array.
161167
-- O(1) time.
168+
{-# INLINE scalar #-}
162169
scalar :: (Unbox a) => a -> Array 0 a
163170
scalar = A . G.scalar
164171

165172
-- | Convert a scalar (rank 0) array to a value.
166173
-- O(1) time.
174+
{-# INLINE unScalar #-}
167175
unScalar :: (Unbox a) => Array 0 a -> a
168176
unScalar = G.unScalar . unA
169177

@@ -182,18 +190,21 @@ mapA f = A . G.mapA f . unA
182190

183191
-- | Map over the array elements.
184192
-- O(n) time.
193+
{-# INLINABLE zipWithA #-}
185194
zipWithA :: (Unbox a, Unbox b, Unbox c) =>
186195
(a -> b -> c) -> Array n a -> Array n b -> Array n c
187196
zipWithA f a b = A $ G.zipWithA f (unA a) (unA b)
188197

189198
-- | Map over the array elements.
190199
-- O(n) time.
200+
{-# INLINABLE zipWith3A #-}
191201
zipWith3A :: (Unbox a, Unbox b, Unbox c, Unbox d) =>
192202
(a -> b -> c -> d) -> Array n a -> Array n b -> Array n c -> Array n d
193203
zipWith3A f a b c = A $ G.zipWith3A f (unA a) (unA b) (unA c)
194204

195205
-- | Pad each dimension on the low and high side with the given value.
196206
-- O(n) time.
207+
{-# INLINABLE pad #-}
197208
pad :: (Unbox a, KnownNat n) => [(Int, Int)] -> a -> Array n a -> Array n a
198209
pad ps v = A . G.pad ps v . unA
199210

@@ -215,6 +226,7 @@ append x y = A $ G.append (unA x) (unA y)
215226
-- | Concatenate a number of arrays into a single array.
216227
-- Fails if any, but the outer, dimensions differ.
217228
-- O(n) time.
229+
{-# INLINABLE concatOuter #-}
218230
concatOuter :: (Unbox a, KnownNat n) => [Array n a] -> Array n a
219231
concatOuter = A . G.concatOuter . coerce
220232

@@ -244,19 +256,22 @@ unravel = R.A . G.mapA A . G.unravel . unA
244256
--
245257
-- If the window parameter @ws = [w1,...,wk]@ and @wa = window ws a@ then
246258
-- @wa `index` i1 ... `index` ik == slice [(i1,w1),...,(ik,wk)] a@.
259+
{-# INLINABLE window #-}
247260
window :: (KnownNat n, KnownNat n') => [Int] -> Array n a -> Array n' a
248261
window ws = A . G.window ws . unA
249262

250263
-- | Stride the outermost dimensions.
251264
-- E.g., if the array shape is @[10,12,8]@ and the strides are
252265
-- @[2,2]@ then the resulting shape will be @[5,6,8]@.
253266
-- O(1) time.
267+
{-# INLINABLE stride #-}
254268
stride :: [Int] -> Array n a -> Array n a
255269
stride ws = A . G.stride ws . unA
256270

257271
-- | Rotate the array k times along the d'th dimension.
258272
-- E.g., if the array shape is @[2, 3, 2]@, d is 1, and k is 4,
259273
-- the resulting shape will be @[2, 4, 3, 2]@.
274+
{-# INLINABLE rotate #-}
260275
rotate :: forall d p a.
261276
(KnownNat p, KnownNat d, Unbox a,
262277
-- Nonsense
@@ -276,13 +291,15 @@ rotate k = A . G.rotate @d @p k . unA
276291
-- The extracted slice must fall within the array dimensions.
277292
-- E.g. @slice [1,2] (fromList [4] [1,2,3,4]) == [2,3]@.
278293
-- O(1) time.
294+
{-# INLINABLE slice #-}
279295
slice :: [(Int, Int)] -> Array n a -> Array n a
280296
slice ss = A . G.slice ss . unA
281297

282298
-- | Apply a function to the subarrays /n/ levels down and make
283299
-- the results into an array with the same /n/ outermost dimensions.
284300
-- The /n/ must not exceed the rank of the array.
285301
-- O(1) time.
302+
{-# INLINABLE rerank #-}
286303
rerank :: forall n i o a b .
287304
(Unbox a, Unbox b, KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
288305
(Array i a -> Array o b) -> Array (n+i) a -> Array (n+o) b
@@ -292,37 +309,44 @@ rerank f = A . G.rerank (unA . f . A) . unA
292309
-- the results into an array with the same /n/ outermost dimensions.
293310
-- The /n/ must not exceed the rank of the array.
294311
-- O(n) time.
312+
{-# INLINABLE rerank2 #-}
295313
rerank2 :: forall n i o a b c .
296314
(Unbox a, Unbox b, Unbox c, KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
297315
(Array i a -> Array i b -> Array o c) -> Array (n+i) a -> Array (n+i) b -> Array (n+o) c
298316
rerank2 f ta tb = A $ G.rerank2 @n (\ a b -> unA $ f (A a) (A b)) (unA ta) (unA tb)
299317

300318
-- | Reverse the given dimensions, with the outermost being dimension 0.
301319
-- O(1) time.
320+
{-# INLINABLE rev #-}
302321
rev :: [Int] -> Array n a -> Array n a
303322
rev rs = A . G.rev rs . unA
304323

305324
-- | Reduce all elements of an array into a rank 0 array.
306325
-- To reduce parts use 'rerank' and 'transpose' together with 'reduce'.
307326
-- O(n) time.
327+
{-# INLINABLE reduce #-}
308328
reduce :: (Unbox a) => (a -> a -> a) -> a -> Array n a -> Array 0 a
309329
reduce f z = A . G.reduce f z . unA
310330

311331
-- | Constrained version of 'foldr' for Arrays.
332+
{-# INLINABLE foldrA #-}
312333
foldrA :: (Unbox a, Unbox b) => (a -> b -> b) -> b -> Array n a -> b
313334
foldrA f z = G.foldrA f z . unA
314335

315336
-- | Constrained version of 'traverse' for Arrays.
337+
{-# INLINABLE traverseA #-}
316338
traverseA
317339
:: (Unbox a, Unbox b, Applicative f)
318340
=> (a -> f b) -> Array n a -> f (Array n b)
319341
traverseA f = fmap A . G.traverseA f . unA
320342

321343
-- | Check if all elements of the array are equal.
344+
{-# INLINABLE allSameA #-}
322345
allSameA :: (Unbox a, Eq a) => Array n a -> Bool
323346
allSameA = G.allSameA . unA
324347

325-
instance (KnownNat r, Arbitrary a, Unbox a) => Arbitrary (Array r a) where arbitrary = A <$> arbitrary
348+
instance (KnownNat r, Arbitrary a, Unbox a) => Arbitrary (Array r a) where
349+
arbitrary = A <$> arbitrary
326350

327351
-- | Sum of all elements.
328352
{-# INLINE sumA #-}
@@ -358,6 +382,7 @@ allA p = G.allA p . unA
358382
-- and just replicate the data along all other dimensions.
359383
-- The list of dimensions indicies must have the same rank as the argument array
360384
-- and it must be strictly ascending.
385+
{-# INLINABLE broadcast #-}
361386
broadcast :: forall r' r a .
362387
(HasCallStack, Unbox a, KnownNat r, KnownNat r') =>
363388
[Int] -> ShapeL -> Array r a -> Array r' a

0 commit comments

Comments
 (0)