@@ -54,6 +54,7 @@ class Vector v where
5454 vLength :: (VecElem v a ) => v a -> Int
5555 vToList :: (VecElem v a ) => v a -> [a ]
5656 vFromList :: (VecElem v a ) => [a ] -> v a
57+ vFromListN :: (VecElem v a ) => Int -> [a ] -> v a
5758 vSingleton :: (VecElem v a ) => a -> v a
5859 vReplicate :: (VecElem v a ) => Int -> a -> v a
5960 vMap :: (VecElem v a , VecElem v b ) => (a -> b ) -> v a -> v b
@@ -84,6 +85,7 @@ instance Vector [] where
8485 vLength = length
8586 vToList = id
8687 vFromList = id
88+ vFromListN _ = id
8789 vSingleton = pure
8890 vReplicate = replicate
8991 vMap = map
@@ -209,41 +211,50 @@ unScalarT (T _ o v) = vIndex v o
209211constantT :: (Vector v , VecElem v a ) => ShapeL -> a -> T v a
210212constantT sh x = T (map (const 0 ) sh) 0 (vSingleton x)
211213
212- -- TODO: change to return a list of vectors.
213- -- Convert an array to a vector in the natural order.
214- {-# INLINE toVectorT #-}
215- toVectorT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> v a
216- toVectorT sh a@ (T ats ao v) =
214+ -- Convert an array to a list of vectors, which together contain
215+ -- all the elements in the natural order.
216+ -- An invariant: if the input array is non-empty the returned list
217+ -- will have no empty vectors.
218+ -- The minimum/maximum operations rely on this invariant.
219+ {-# INLINE toVectorListT #-}
220+ toVectorListT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> [v a ]
221+ toVectorListT sh a@ (T ats ao v) =
217222 let l : ts' = getStridesT sh
218223 -- Are strides ok from this point?
219224 oks = scanr (&&) True (zipWith (==) ats ts')
220- loop _ [] _ o =
221- DL. singleton (vSlice o 1 v)
222- loop (b: bs) (s: ss) (t: ts) o =
225+ loop (b: bs) (s: ss) (t: ts) ! o =
223226 if b then
224227 -- All strides normal from this point,
225228 -- so just take a slice of the underlying vector.
226229 DL. singleton (vSlice o (s* t) v)
227230 else
228231 -- Strides are not normal, collect slices.
229232 DL. concat [ loop bs ss ts (i* t + o) | i <- [0 .. s- 1 ] ]
230- loop _ _ _ _ = error " impossible"
231- in if head oks && vLength v == l then
233+ loop _ _ _ _ = error " impossible" -- due to how @loop@ is called
234+ in if ats == ts' && vLength v == l then
232235 -- All strides are normal, return entire vector
233- v
234- else if oks !! length sh then -- Special case for speed.
236+ [v]
237+ else if null sh then
238+ [vSlice ao 1 v]
239+ else if oks !! (length sh - 1 ) then -- Special case for speed.
235240 -- Innermost dimension is normal, so slices are non-trivial.
236- vConcat $ DL. toList $ loop oks sh ats ao
241+ DL. toList $ loop oks sh ats ao
237242 else
238243 -- All slices would have length 1, going via a list is faster.
239- vFromList $ toListT sh a
244+ [vFromListN l $ toListT sh a]
245+
246+ {-# INLINE toVectorT #-}
247+ toVectorT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> v a
248+ toVectorT sh a = case toVectorListT sh a of
249+ [v] -> v
250+ l -> vConcat l
240251
241- -- Convert to a vector containing the right elements,
252+ -- Convert to a list of vectors containing altogether the right elements,
242253-- but not necessarily in the right order.
243254-- This is used for reduction with commutative&associative operations.
244- {-# INLINE toUnorderedVectorT #-}
245- toUnorderedVectorT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> v a
246- toUnorderedVectorT sh a@ (T ats ao v) =
255+ {-# INLINE toUnorderedVectorListT #-}
256+ toUnorderedVectorListT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> [ v a ]
257+ toUnorderedVectorListT sh a@ (T ats ao v) =
247258 -- Figure out if the array maps onto some contiguous slice of the vector.
248259 -- Do this by checking if a transposition of the array corresponds to
249260 -- normal strides.
@@ -256,9 +267,15 @@ toUnorderedVectorT sh a@(T ats ao v) =
256267 l : ts' = getStridesT sh'
257268 in
258269 if ats' == ts' then
259- vSlice ao l v
270+ [ vSlice ao l v]
260271 else
261- toVectorT sh a
272+ toVectorListT sh a
273+
274+ {-# INLINE toUnorderedVectorT #-}
275+ toUnorderedVectorT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> v a
276+ toUnorderedVectorT sh a = case toUnorderedVectorListT sh a of
277+ [v] -> v
278+ l -> vConcat l
262279
263280-- Convert from a vector.
264281{-# INLINE fromVectorT #-}
@@ -268,7 +285,7 @@ fromVectorT sh = T (tail $ getStridesT sh) 0
268285-- Convert from a list
269286{-# INLINE fromListT #-}
270287fromListT :: (Vector v , VecElem v a ) => [Int ] -> [a ] -> T v a
271- fromListT sh = fromVectorT sh . vFromList
288+ fromListT sh = fromVectorT sh . vFromListN ( product sh)
272289
273290-- Index into the outermost dimension of an array.
274291{-# INLINE indexT #-}
@@ -373,7 +390,7 @@ reverseT rs sh (T ats ao v) = T rts ro v
373390{-# INLINE reduceT #-}
374391reduceT :: (Vector v , VecElem v a ) =>
375392 ShapeL -> (a -> a -> a ) -> a -> T v a -> T v a
376- reduceT sh f z = scalarT . vFold f z . toVectorT sh
393+ reduceT sh f z = scalarT . foldl' ( vFold f) z . toVectorListT sh
377394
378395-- Right fold via toListT.
379396{-# INLINE foldrT #-}
@@ -389,13 +406,14 @@ traverseT
389406traverseT sh f a = fmap (fromListT sh) (traverse f (toListT sh a))
390407
391408-- Fast check if all elements are equal.
409+ {-# INLINABLE allSameT #-}
392410allSameT :: (Vector v , VecElem v a , Eq a ) => ShapeL -> T v a -> Bool
393411allSameT sh t@ (T _ _ v)
394412 | vLength v <= 1 = True
395413 | otherwise =
396- let ! v' = toVectorT sh t
397- ! x = vIndex v' 0
398- in vAll (x == ) v'
414+ let ! l = toVectorListT sh t
415+ ! x = vIndex (l !! 0 ) 0
416+ in all ( vAll (x == )) l
399417
400418newtype Rect = Rect { unRect :: [String ] } -- A rectangle of text
401419
@@ -482,10 +500,11 @@ zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
482500zipWithLong2 f (a: as) (b: bs) = f a b : zipWithLong2 f as bs
483501zipWithLong2 _ _ bs = bs
484502
503+ {-# INLINABLE padT #-}
485504padT :: forall v a . (Vector v , VecElem v a ) => a -> [(Int , Int )] -> ShapeL -> T v a -> ([Int ], T v a )
486505padT v aps ash at = (ss, fromVectorT ss $ vConcat $ pad' aps ash st at)
487506 where pad' :: [(Int , Int )] -> ShapeL -> [Int ] -> T v a -> [v a ]
488- pad' [] sh _ t = [toVectorT sh t]
507+ pad' [] sh _ t = toVectorListT sh t
489508 pad' ((l,h): ps) (s: sh) (n: ns) t =
490509 [vReplicate (n* l) v] ++ concatMap (pad' ps sh ns . indexT t) [0 .. s- 1 ] ++ [vReplicate (n* h) v]
491510 pad' _ _ _ _ = error $ " pad: rank mismatch " ++ show (length aps, length ash)
@@ -513,30 +532,30 @@ simpleReshape _ _ _ = Nothing
513532-- Note: assumes + is commutative&associative.
514533{-# INLINE sumT #-}
515534sumT :: (Vector v , VecElem v a , Num a ) => ShapeL -> T v a -> a
516- sumT sh = vSum . toUnorderedVectorT sh
535+ sumT sh = sum . map vSum . toUnorderedVectorListT sh
517536
518537-- Note: assumes * is commutative&associative.
519538{-# INLINE productT #-}
520539productT :: (Vector v , VecElem v a , Num a ) => ShapeL -> T v a -> a
521- productT sh = vProduct . toUnorderedVectorT sh
540+ productT sh = product . map vProduct . toUnorderedVectorListT sh
522541
523542-- Note: assumes max is commutative&associative.
524543{-# INLINE maximumT #-}
525544maximumT :: (Vector v , VecElem v a , Ord a ) => ShapeL -> T v a -> a
526- maximumT sh = vMaximum . toUnorderedVectorT sh
545+ maximumT sh = maximum . map vMaximum . toUnorderedVectorListT sh
527546
528547-- Note: assumes min is commutative&associative.
529548{-# INLINE minimumT #-}
530549minimumT :: (Vector v , VecElem v a , Ord a ) => ShapeL -> T v a -> a
531- minimumT sh = vMinimum . toUnorderedVectorT sh
550+ minimumT sh = minimum . map vMinimum . toUnorderedVectorListT sh
532551
533552{-# INLINE anyT #-}
534553anyT :: (Vector v , VecElem v a ) => ShapeL -> (a -> Bool ) -> T v a -> Bool
535- anyT sh p = vAny p . toUnorderedVectorT sh
554+ anyT sh p = or . map ( vAny p) . toUnorderedVectorListT sh
536555
537556{-# INLINE allT #-}
538557allT :: (Vector v , VecElem v a ) => ShapeL -> (a -> Bool ) -> T v a -> Bool
539- allT sh p = vAll p . toUnorderedVectorT sh
558+ allT sh p = and . map ( vAll p) . toUnorderedVectorListT sh
540559
541560{-# INLINE updateT #-}
542561updateT :: (Vector v , VecElem v a ) => ShapeL -> T v a -> [([Int ], a )] -> T v a
@@ -563,13 +582,15 @@ iotaT n = fromListT [n] [0 .. fromIntegral n - 1] -- TODO: should use V.enumF
563582-------
564583
565584-- | Permute the elements of a list, the first argument is indices into the original list.
585+ {-# INLINE permute #-}
566586permute :: [Int ] -> [a ] -> [a ]
567587permute is xs = map (xs!! ) is
568588
569589-- | Like 'dropWhile' but at the end of the list.
570590revDropWhile :: (a -> Bool ) -> [a ] -> [a ]
571591revDropWhile p = reverse . dropWhile p . reverse
572592
593+ {-# INLINABLE allSame #-}
573594allSame :: (Eq a ) => [a ] -> Bool
574595allSame [] = True
575596allSame (x : xs) = all (x == ) xs
0 commit comments