Skip to content

Commit cea2f74

Browse files
committed
Traverse arrays better
The previous approach was to convert the array to a list, traverse the list, and then put it back into an array. I figured it might be worth trying a more direct approach. It seems to work rather better. I also added rules to use a very direct implementation for `ST` and `IO`.
1 parent 9d67eab commit cea2f74

File tree

2 files changed

+55
-36
lines changed

2 files changed

+55
-36
lines changed

Data/HashMap/Array.hs

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples #-}
1+
{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples, ScopedTypeVariables #-}
22
{-# OPTIONS_GHC -fno-full-laziness -funbox-strict-fields #-}
33

44
-- | Zero based arrays.
@@ -57,8 +57,9 @@ import Control.Applicative (Applicative (..), (<$>))
5757
#endif
5858
import Control.Applicative (liftA2)
5959
import Control.DeepSeq
60-
import GHC.Exts(Int(..), Int#, reallyUnsafePtrEquality#, tagToEnum#, unsafeCoerce#, State#)
60+
import GHC.Exts(Int(..), Int#, reallyUnsafePtrEquality#, tagToEnum#, unsafeCoerce#, State#, (+#))
6161
import GHC.ST (ST(..))
62+
import Control.Monad.ST (stToIO)
6263

6364
#if __GLASGOW_HASKELL__ >= 709
6465
import Prelude hiding (filter, foldr, length, map, read, traverse)
@@ -475,49 +476,36 @@ fromList n xs0 =
475476
toList :: Array a -> [a]
476477
toList = foldr (:) []
477478

478-
data SList a = SCons !a (SList a) | SNil
479+
newtype STA a = STA {_runSTA :: forall s. MutableArray# s a -> ST s (Array a)}
479480

480-
traverseToSList
481-
:: Applicative f
482-
=> (a -> f b) -> [a] -> f (SList b)
483-
traverseToSList f = go
484-
where
485-
go (a : as) = liftA2 SCons (f a) (go as)
486-
go [] = pure SNil
487-
488-
_slength :: SList a -> Int
489-
_slength = go 0 where
490-
go !acc SNil = acc
491-
go acc (SCons _ xs) = go (acc + 1) xs
481+
runSTA :: Int -> STA a -> Array a
482+
runSTA !n (STA m) = runST $ new_ n >>= \ (MArray ar) -> m ar
492483

493-
fromSList :: Int -> SList a -> Array a
494-
fromSList n xs0 =
495-
CHECK_EQ("fromSList", n, _slength xs0)
496-
run $ do
497-
mary <- new_ n
498-
go xs0 mary 0
484+
traverse :: forall f a b. Applicative f => (a -> f b) -> Array a -> f (Array b)
485+
traverse f = \ !ary -> runSTA (length ary) <$> foldr go stop ary 0#
499486
where
500-
go SNil !mary !_ = return mary
501-
go (SCons x xs) mary i = do write mary i x
502-
go xs mary (i + 1)
503-
504-
traverse :: Applicative f => (a -> f b) -> Array a -> f (Array b)
505-
traverse f = \ ary -> fromList (length ary) `fmap`
506-
Traversable.traverse f (toList ary)
487+
go :: a -> (Int# -> f (STA b)) -> Int# -> f (STA b)
488+
go a r i = liftA2 (\b (STA m) -> STA $ \mry# -> write (MArray mry#) (I# i) b >> m mry#) (f a) (r (i +# 1#))
489+
stop :: Int# -> f (STA b)
490+
stop _i = pure (STA (\mry# -> unsafeFreeze (MArray mry#)))
507491
{-# INLINE [1] traverse #-}
508492

509-
traverse' :: Applicative f => (a -> f b) -> Array a -> f (Array b)
510-
traverse' f = \ary -> fromSList (length ary) `fmap`
511-
traverseToSList f (toList ary)
493+
traverse' :: forall f a b. Applicative f => (a -> f b) -> Array a -> f (Array b)
494+
traverse' f = \ !ary -> runSTA (length ary) <$> foldr go stop ary 0#
495+
where
496+
go :: a -> (Int# -> f (STA b)) -> Int# -> f (STA b)
497+
go a r i = liftA2 (\ !b (STA m) -> STA $ \mry# -> write (MArray mry#) (I# i) b >> m mry#) (f a) (r (i +# 1#))
498+
stop :: Int# -> f (STA b)
499+
stop _i = pure (STA (\mry# -> unsafeFreeze (MArray mry#)))
512500
{-# INLINE [1] traverse' #-}
513501

514-
-- Traversing in ST, we don't need to make a list; we
502+
-- Traversing in ST, we don't need to get fancy; we
515503
-- can just do it directly.
516504
traverseST :: (a -> ST s b) -> Array a -> ST s (Array b)
517505
traverseST f = \ ary0 ->
518506
let
519507
!len = length ary0
520-
go k mary
508+
go k !mary
521509
| k == len = return mary
522510
| otherwise = do
523511
x <- indexM ary0 k
@@ -527,17 +515,33 @@ traverseST f = \ ary0 ->
527515
in new_ len >>= (go 0 >=> unsafeFreeze)
528516
{-# INLINE traverseST #-}
529517

518+
traverseIO :: (a -> IO b) -> Array a -> IO (Array b)
519+
traverseIO f = \ ary0 ->
520+
let
521+
!len = length ary0
522+
go k !mary
523+
| k == len = return mary
524+
| otherwise = do
525+
x <- stToIO $ indexM ary0 k
526+
y <- f x
527+
stToIO $ write mary k y
528+
go (k + 1) mary
529+
in stToIO (new_ len) >>= (go 0 >=> stToIO . unsafeFreeze)
530+
{-# INLINE traverseIO #-}
531+
532+
530533
{-# RULES
531534
"traverse/ST" forall f. traverse f = traverseST f
535+
"traverse/IO" forall f. traverse f = traverseIO f
532536
#-}
533537

534-
-- Traversing in ST, we don't need to make a list; we
538+
-- Traversing in ST, we don't need to get fancy; we
535539
-- can just do it directly.
536540
traverseST' :: (a -> ST s b) -> Array a -> ST s (Array b)
537541
traverseST' f = \ ary0 ->
538542
let
539543
!len = length ary0
540-
go k mary
544+
go k !mary
541545
| k == len = return mary
542546
| otherwise = do
543547
x <- indexM ary0 k
@@ -547,8 +551,23 @@ traverseST' f = \ ary0 ->
547551
in new_ len >>= (go 0 >=> unsafeFreeze)
548552
{-# INLINE traverseST' #-}
549553

554+
traverseIO' :: (a -> IO b) -> Array a -> IO (Array b)
555+
traverseIO' f = \ ary0 ->
556+
let
557+
!len = length ary0
558+
go k !mary
559+
| k == len = return mary
560+
| otherwise = do
561+
x <- stToIO $ indexM ary0 k
562+
!y <- f x
563+
stToIO $ write mary k y
564+
go (k + 1) mary
565+
in stToIO (new_ len) >>= (go 0 >=> stToIO . unsafeFreeze)
566+
{-# INLINE traverseIO' #-}
567+
550568
{-# RULES
551569
"traverse'/ST" forall f. traverse' f = traverseST' f
570+
"traverse'/IO" forall f. traverse' f = traverseIO' f
552571
#-}
553572

554573
filter :: (a -> Bool) -> Array a -> Array a

Data/HashMap/Strict/Base.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ module Data.HashMap.Strict.Base
9393
import Data.Bits ((.&.), (.|.))
9494

9595
#if !MIN_VERSION_base(4,8,0)
96-
import Data.Functor((<$>))
96+
import Control.Applicative (Applicative (..), (<$>))
9797
#endif
9898
import qualified Data.List as L
9999
import Data.Hashable (Hashable)

0 commit comments

Comments
 (0)