Skip to content

Commit e45b24e

Browse files
authored
Merge pull request #19 from Mikolaj/inlineable
Add INLINE or INLINABLE pragmas to all storable ranked code
2 parents be7d4f8 + a23efbe commit e45b24e

File tree

8 files changed

+113
-57
lines changed

8 files changed

+113
-57
lines changed

Data/Array/Internal.hs

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Vector v where
5454
vLength :: (VecElem v a) => v a -> Int
5555
vToList :: (VecElem v a) => v a -> [a]
5656
vFromList :: (VecElem v a) => [a] -> v a
57+
vFromListN:: (VecElem v a) => Int -> [a] -> v a
5758
vSingleton:: (VecElem v a) => a -> v a
5859
vReplicate:: (VecElem v a) => Int -> a -> v a
5960
vMap :: (VecElem v a, VecElem v b) => (a -> b) -> v a -> v b
@@ -84,6 +85,7 @@ instance Vector [] where
8485
vLength = length
8586
vToList = id
8687
vFromList = id
88+
vFromListN _ = id
8789
vSingleton = pure
8890
vReplicate = replicate
8991
vMap = map
@@ -209,41 +211,50 @@ unScalarT (T _ o v) = vIndex v o
209211
constantT :: (Vector v, VecElem v a) => ShapeL -> a -> T v a
210212
constantT sh x = T (map (const 0) sh) 0 (vSingleton x)
211213

212-
-- TODO: change to return a list of vectors.
213-
-- Convert an array to a vector in the natural order.
214-
{-# INLINE toVectorT #-}
215-
toVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
216-
toVectorT sh a@(T ats ao v) =
214+
-- Convert an array to a list of vectors, which together contain
215+
-- all the elements in the natural order.
216+
-- An invariant: if the input array is non-empty the returned list
217+
-- will have no empty vectors.
218+
-- The minimum/maximum operations rely on this invariant.
219+
{-# INLINE toVectorListT #-}
220+
toVectorListT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [v a]
221+
toVectorListT sh a@(T ats ao v) =
217222
let l : ts' = getStridesT sh
218223
-- Are strides ok from this point?
219224
oks = scanr (&&) True (zipWith (==) ats ts')
220-
loop _ [] _ o =
221-
DL.singleton (vSlice o 1 v)
222-
loop (b:bs) (s:ss) (t:ts) o =
225+
loop (b:bs) (s:ss) (t:ts) !o =
223226
if b then
224227
-- All strides normal from this point,
225228
-- so just take a slice of the underlying vector.
226229
DL.singleton (vSlice o (s*t) v)
227230
else
228231
-- Strides are not normal, collect slices.
229232
DL.concat [ loop bs ss ts (i*t + o) | i <- [0 .. s-1] ]
230-
loop _ _ _ _ = error "impossible"
231-
in if head oks && vLength v == l then
233+
loop _ _ _ _ = error "impossible" -- due to how @loop@ is called
234+
in if ats == ts' && vLength v == l then
232235
-- All strides are normal, return entire vector
233-
v
234-
else if oks !! length sh then -- Special case for speed.
236+
[v]
237+
else if null sh then
238+
[vSlice ao 1 v]
239+
else if oks !! (length sh - 1) then -- Special case for speed.
235240
-- Innermost dimension is normal, so slices are non-trivial.
236-
vConcat $ DL.toList $ loop oks sh ats ao
241+
DL.toList $ loop oks sh ats ao
237242
else
238243
-- All slices would have length 1, going via a list is faster.
239-
vFromList $ toListT sh a
244+
[vFromListN l $ toListT sh a]
245+
246+
{-# INLINE toVectorT #-}
247+
toVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
248+
toVectorT sh a = case toVectorListT sh a of
249+
[v] -> v
250+
l -> vConcat l
240251

241-
-- Convert to a vector containing the right elements,
252+
-- Convert to a list of vectors containing altogether the right elements,
242253
-- but not necessarily in the right order.
243254
-- This is used for reduction with commutative&associative operations.
244-
{-# INLINE toUnorderedVectorT #-}
245-
toUnorderedVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
246-
toUnorderedVectorT sh a@(T ats ao v) =
255+
{-# INLINE toUnorderedVectorListT #-}
256+
toUnorderedVectorListT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [v a]
257+
toUnorderedVectorListT sh a@(T ats ao v) =
247258
-- Figure out if the array maps onto some contiguous slice of the vector.
248259
-- Do this by checking if a transposition of the array corresponds to
249260
-- normal strides.
@@ -256,9 +267,15 @@ toUnorderedVectorT sh a@(T ats ao v) =
256267
l : ts' = getStridesT sh'
257268
in
258269
if ats' == ts' then
259-
vSlice ao l v
270+
[vSlice ao l v]
260271
else
261-
toVectorT sh a
272+
toVectorListT sh a
273+
274+
{-# INLINE toUnorderedVectorT #-}
275+
toUnorderedVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
276+
toUnorderedVectorT sh a = case toUnorderedVectorListT sh a of
277+
[v] -> v
278+
l -> vConcat l
262279

263280
-- Convert from a vector.
264281
{-# INLINE fromVectorT #-}
@@ -268,7 +285,7 @@ fromVectorT sh = T (tail $ getStridesT sh) 0
268285
-- Convert from a list
269286
{-# INLINE fromListT #-}
270287
fromListT :: (Vector v, VecElem v a) => [Int] -> [a] -> T v a
271-
fromListT sh = fromVectorT sh . vFromList
288+
fromListT sh = fromVectorT sh . vFromListN (product sh)
272289

273290
-- Index into the outermost dimension of an array.
274291
{-# INLINE indexT #-}
@@ -373,7 +390,7 @@ reverseT rs sh (T ats ao v) = T rts ro v
373390
{-# INLINE reduceT #-}
374391
reduceT :: (Vector v, VecElem v a) =>
375392
ShapeL -> (a -> a -> a) -> a -> T v a -> T v a
376-
reduceT sh f z = scalarT . vFold f z . toVectorT sh
393+
reduceT sh f z = scalarT . foldl' (vFold f) z . toVectorListT sh
377394

378395
-- Right fold via toListT.
379396
{-# INLINE foldrT #-}
@@ -389,13 +406,14 @@ traverseT
389406
traverseT sh f a = fmap (fromListT sh) (traverse f (toListT sh a))
390407

391408
-- Fast check if all elements are equal.
409+
{-# INLINABLE allSameT #-}
392410
allSameT :: (Vector v, VecElem v a, Eq a) => ShapeL -> T v a -> Bool
393411
allSameT sh t@(T _ _ v)
394412
| vLength v <= 1 = True
395413
| otherwise =
396-
let !v' = toVectorT sh t
397-
!x = vIndex v' 0
398-
in vAll (x ==) v'
414+
let !l = toVectorListT sh t
415+
!x = vIndex (l !! 0) 0
416+
in all (vAll (x ==)) l
399417

400418
newtype Rect = Rect { unRect :: [String] } -- A rectangle of text
401419

@@ -482,10 +500,11 @@ zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
482500
zipWithLong2 f (a:as) (b:bs) = f a b : zipWithLong2 f as bs
483501
zipWithLong2 _ _ bs = bs
484502

503+
{-# INLINABLE padT #-}
485504
padT :: forall v a . (Vector v, VecElem v a) => a -> [(Int, Int)] -> ShapeL -> T v a -> ([Int], T v a)
486505
padT v aps ash at = (ss, fromVectorT ss $ vConcat $ pad' aps ash st at)
487506
where pad' :: [(Int, Int)] -> ShapeL -> [Int] -> T v a -> [v a]
488-
pad' [] sh _ t = [toVectorT sh t]
507+
pad' [] sh _ t = toVectorListT sh t
489508
pad' ((l,h):ps) (s:sh) (n:ns) t =
490509
[vReplicate (n*l) v] ++ concatMap (pad' ps sh ns . indexT t) [0..s-1] ++ [vReplicate (n*h) v]
491510
pad' _ _ _ _ = error $ "pad: rank mismatch " ++ show (length aps, length ash)
@@ -513,30 +532,30 @@ simpleReshape _ _ _ = Nothing
513532
-- Note: assumes + is commutative&associative.
514533
{-# INLINE sumT #-}
515534
sumT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
516-
sumT sh = vSum . toUnorderedVectorT sh
535+
sumT sh = sum . map vSum . toUnorderedVectorListT sh
517536

518537
-- Note: assumes * is commutative&associative.
519538
{-# INLINE productT #-}
520539
productT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
521-
productT sh = vProduct . toUnorderedVectorT sh
540+
productT sh = product . map vProduct . toUnorderedVectorListT sh
522541

523542
-- Note: assumes max is commutative&associative.
524543
{-# INLINE maximumT #-}
525544
maximumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
526-
maximumT sh = vMaximum . toUnorderedVectorT sh
545+
maximumT sh = maximum . map vMaximum . toUnorderedVectorListT sh
527546

528547
-- Note: assumes min is commutative&associative.
529548
{-# INLINE minimumT #-}
530549
minimumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
531-
minimumT sh = vMinimum . toUnorderedVectorT sh
550+
minimumT sh = minimum . map vMinimum . toUnorderedVectorListT sh
532551

533552
{-# INLINE anyT #-}
534553
anyT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
535-
anyT sh p = vAny p . toUnorderedVectorT sh
554+
anyT sh p = or . map (vAny p) . toUnorderedVectorListT sh
536555

537556
{-# INLINE allT #-}
538557
allT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
539-
allT sh p = vAll p . toUnorderedVectorT sh
558+
allT sh p = and . map (vAll p) . toUnorderedVectorListT sh
540559

541560
{-# INLINE updateT #-}
542561
updateT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [([Int], a)] -> T v a
@@ -563,13 +582,15 @@ iotaT n = fromListT [n] [0 .. fromIntegral n - 1] -- TODO: should use V.enumF
563582
-------
564583

565584
-- | Permute the elements of a list, the first argument is indices into the original list.
585+
{-# INLINE permute #-}
566586
permute :: [Int] -> [a] -> [a]
567587
permute is xs = map (xs!!) is
568588

569589
-- | Like 'dropWhile' but at the end of the list.
570590
revDropWhile :: (a -> Bool) -> [a] -> [a]
571591
revDropWhile p = reverse . dropWhile p . reverse
572592

593+
{-# INLINABLE allSame #-}
573594
allSame :: (Eq a) => [a] -> Bool
574595
allSame [] = True
575596
allSame (x : xs) = all (x ==) xs

Data/Array/Internal/Dynamic.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ instance Vector V.Vector where
6565
vToList = V.toList
6666
{-# INLINE vFromList #-}
6767
vFromList = V.fromList
68+
{-# INLINE vFromListN #-}
69+
vFromListN = V.fromListN
6870
{-# INLINE vSingleton #-}
6971
vSingleton = V.singleton
7072
{-# INLINE vReplicate #-}

Data/Array/Internal/DynamicG.hs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
-- limitations under the License.
1414

1515
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
16-
{-# LANGUAGE DeriveDataTypeable #-}
17-
{-# LANGUAGE DeriveGeneric #-}
18-
{-# LANGUAGE FlexibleInstances #-}
16+
{-# LANGUAGE DeriveDataTypeable #-}
17+
{-# LANGUAGE DeriveGeneric #-}
18+
{-# LANGUAGE FlexibleInstances #-}
1919
{-# LANGUAGE MultiParamTypeClasses #-}
20-
{-# LANGUAGE RoleAnnotations #-}
21-
{-# LANGUAGE ScopedTypeVariables #-}
22-
{-# LANGUAGE UndecidableInstances #-}
20+
{-# LANGUAGE RoleAnnotations #-}
21+
{-# LANGUAGE ScopedTypeVariables #-}
22+
{-# LANGUAGE UndecidableInstances #-}
2323
-- | Arrays of dynamic size. The arrays are polymorphic in the underlying
2424
-- linear data structure used to store the actual values.
2525
module Data.Array.Internal.DynamicG(
@@ -43,16 +43,16 @@ module Data.Array.Internal.DynamicG(
4343
update,
4444
generate, iterateN, iota,
4545
) where
46-
import Control.DeepSeq
47-
import Control.Monad(replicateM)
48-
import Data.Data(Data)
49-
import Data.List(sort)
50-
import GHC.Generics(Generic)
51-
import GHC.Stack
52-
import Test.QuickCheck hiding (generate)
53-
import Text.PrettyPrint.HughesPJClass hiding ((<>))
46+
import Control.DeepSeq
47+
import Control.Monad (replicateM)
48+
import Data.Data (Data)
49+
import Data.List (sort)
50+
import GHC.Generics (Generic)
51+
import GHC.Stack
52+
import Test.QuickCheck hiding (generate)
53+
import Text.PrettyPrint.HughesPJClass hiding ((<>))
5454

55-
import Data.Array.Internal
55+
import Data.Array.Internal
5656

5757
-- | Arrays stored in a /v/ with values of type /a/.
5858
type role Array representational nominal
@@ -126,7 +126,7 @@ toVector (A sh t) = toVectorT sh t
126126
{-# INLINE fromList #-}
127127
fromList :: (HasCallStack, Vector v, VecElem v a) => ShapeL -> [a] -> Array v a
128128
fromList ss vs | n /= l = error $ "fromList: size mismatch " ++ show (n, l)
129-
| otherwise = A ss $ T st 0 $ vFromList vs
129+
| otherwise = A ss $ T st 0 $ vFromListN l vs
130130
where n : st = getStridesT ss
131131
l = length vs
132132

@@ -190,7 +190,7 @@ scalar = A [] . scalarT
190190
{-# INLINE unScalar #-}
191191
unScalar :: (HasCallStack, Vector v, VecElem v a) => Array v a -> a
192192
unScalar (A [] t) = unScalarT t
193-
unScalar _ = error "unScalar: not a scalar"
193+
unScalar _ = error "unScalar: not a scalar"
194194

195195
-- | Make an array with all elements having the same value.
196196
-- O(1) time
@@ -330,8 +330,8 @@ window aws (A ash (T ss o v)) = A (win aws ash) (T (ss' ++ ss) o v)
330330
stride :: (HasCallStack, Vector v) => [Int] -> Array v a -> Array v a
331331
stride ats (A ash (T ss o v)) = A (str ats ash) (T (zipWith (*) (ats ++ repeat 1) ss) o v)
332332
where str (t:ts) (s:sh) = (s+t-1) `quot` t : str ts sh
333-
str [] sh = sh
334-
str _ _ = error $ "stride: rank mismatch " ++ show (ats, ash)
333+
str [] sh = sh
334+
str _ _ = error $ "stride: rank mismatch " ++ show (ats, ash)
335335

336336
-- | Rotate the array k times along the d'th dimension.
337337
-- E.g., if the array shape is @[2, 3, 2]@, d is 1, and k is 4,
@@ -490,7 +490,7 @@ broadcast ds sh a | length ds /= rank a = error "broadcast: wrong number of broa
490490
where r = length sh
491491
rsh = [ if i `elem` ds then s else 1 | (i, s) <- zip [0..] sh ]
492492
ascending (x:y:ys) = x < y && ascending (y:ys)
493-
ascending _ = True
493+
ascending _ = True
494494

495495
-- | Update the array at the specified indicies to the associated value.
496496
update :: (HasCallStack, Vector v, VecElem v a) =>

Data/Array/Internal/DynamicS.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ instance Vector V.Vector where
7272
vToList = V.toList
7373
{-# INLINE vFromList #-}
7474
vFromList = V.fromList
75+
{-# INLINE vFromListN #-}
76+
vFromListN = V.fromListN
7577
{-# INLINE vSingleton #-}
7678
vSingleton = V.singleton
7779
{-# INLINE vReplicate #-}

Data/Array/Internal/DynamicU.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ instance Vector V.Vector where
6969
vToList = V.toList
7070
{-# INLINE vFromList #-}
7171
vFromList = V.fromList
72+
{-# INLINE vFromListN #-}
73+
vFromListN = V.fromListN
7274
{-# INLINE vSingleton #-}
7375
vSingleton = V.singleton
7476
{-# INLINE vReplicate #-}

Data/Array/Internal/RankedG.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ data Array (n :: Nat) v a = A !ShapeL !(T v a)
6969
deriving (Generic, Data)
7070

7171
instance (Vector v, Show a, VecElem v a) => Show (Array n v a) where
72-
{-# INLINABLE showsPrec #-}
7372
showsPrec p a@(A s _) = showParen (p > 10) $
7473
showString "fromList " . showsPrec 11 s . showString " " . showsPrec 11 (toList a)
74+
{-# INLINABLE showsPrec #-}
7575

7676
instance (KnownNat n, Vector v, Read a, VecElem v a) => Read (Array n v a) where
7777
readsPrec p = readParen (p > 10) $ \ r1 ->
@@ -92,6 +92,7 @@ instance (Vector v, Pretty a, VecElem v a) => Pretty (Array n v a) where
9292

9393
instance (NFData (v a)) => NFData (Array n v a) where
9494
rnf (A sh v) = rnf sh `seq` rnf v
95+
{-# INLINE rnf #-}
9596

9697
-- | The number of elements in the array.
9798
-- O(1) time.
@@ -141,7 +142,7 @@ fromList :: forall n v a . (HasCallStack, Vector v, VecElem v a, KnownNat n) =>
141142
ShapeL -> [a] -> Array n v a
142143
fromList ss vs | n /= l = error $ "fromList: size mismatch " ++ show (n, l)
143144
| length ss /= valueOf @n = error $ "fromList: rank mismatch " ++ show (length ss, valueOf @n :: Int)
144-
| otherwise = A ss $ T st 0 $ vFromList vs
145+
| otherwise = A ss $ T st 0 $ vFromListN l vs
145146
where n : st = getStridesT ss
146147
l = length vs
147148

@@ -218,7 +219,7 @@ constant :: forall n v a . (Vector v, VecElem v a, KnownNat n) =>
218219
ShapeL -> a -> Array n v a
219220
constant sh | badShape sh = error $ "constant: bad shape: " ++ show sh
220221
| length sh /= valueOf @n = error "constant: rank mismatch"
221-
| otherwise = A sh . constantT sh
222+
| otherwise = A sh . constantT sh
222223

223224
-- | Map over the array elements.
224225
-- O(n) time.
@@ -344,6 +345,7 @@ stride ats (A ash (T ss o v)) = A (str ats ash) (T (zipWith (*) (ats ++ repeat 1
344345
-- | Rotate the array k times along the d'th dimension.
345346
-- E.g., if the array shape is @[2, 3, 2]@, d is 1, and k is 4,
346347
-- the resulting shape will be @[2, 4, 3, 2]@.
348+
{-# INLINABLE rotate #-}
347349
rotate :: forall d p v a.
348350
(KnownNat p, KnownNat d,
349351
Vector v, VecElem v a,
@@ -455,6 +457,7 @@ traverseA
455457
traverseA f (A sh t) = A sh <$> traverseT sh f t
456458

457459
-- | Check if all elements of the array are equal.
460+
{-# INLINE allSameA #-}
458461
allSameA :: (Vector v, VecElem v a, Eq a) => Array r v a -> Bool
459462
allSameA (A sh t) = allSameT sh t
460463

@@ -500,6 +503,7 @@ allA p (A sh t) = allT sh p t
500503
-- and just replicate the data along all other dimensions.
501504
-- The list of dimensions indicies must have the same rank as the argument array
502505
-- and it must be strictly ascending.
506+
{-# INLINABLE broadcast #-}
503507
broadcast :: forall r' r v a .
504508
(HasCallStack, Vector v, VecElem v a, KnownNat r, KnownNat r') =>
505509
[Int] -> ShapeL -> Array r v a -> Array r' v a

0 commit comments

Comments
 (0)