Skip to content

Commit 73e13df

Browse files
authored
Merge pull request #194 from treeowl/strict-traverse
Make Strict.traverseWithKey actually strict
2 parents cba2e43 + d0dcc17 commit 73e13df

File tree

5 files changed

+133
-21
lines changed

5 files changed

+133
-21
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
* Add `HashMap.keysSet`.
66

7+
* Make `HashMap.Strict.traverseWithKey` force the results before
8+
installing them in the map.
9+
710
## 0.2.9.0
811

912
* Add `Ord/Ord1/Ord2` instances. (Thanks, Oleg Grenrus)

Data/HashMap/Array.hs

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples #-}
1+
{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples, ScopedTypeVariables #-}
22
{-# OPTIONS_GHC -fno-full-laziness -funbox-strict-fields #-}
33

44
-- | Zero based arrays.
@@ -46,17 +46,20 @@ module Data.HashMap.Array
4646
, map
4747
, map'
4848
, traverse
49+
, traverse'
4950
, filter
5051
, toList
52+
, fromList
5153
) where
5254

53-
import qualified Data.Traversable as Traversable
54-
#if __GLASGOW_HASKELL__ < 709
55-
import Control.Applicative (Applicative)
55+
#if !MIN_VERSION_base(4,8,0)
56+
import Control.Applicative (Applicative (..), (<$>))
5657
#endif
58+
import Control.Applicative (liftA2)
5759
import Control.DeepSeq
58-
import GHC.Exts(Int(..), Int#, reallyUnsafePtrEquality#, tagToEnum#, unsafeCoerce#, State#)
60+
import GHC.Exts(Int(..), Int#, reallyUnsafePtrEquality#, tagToEnum#, unsafeCoerce#, State#, (+#))
5961
import GHC.ST (ST(..))
62+
import Control.Monad.ST (stToIO)
6063

6164
#if __GLASGOW_HASKELL__ >= 709
6265
import Prelude hiding (filter, foldr, length, map, read, traverse)
@@ -473,18 +476,36 @@ fromList n xs0 =
473476
toList :: Array a -> [a]
474477
toList = foldr (:) []
475478

476-
traverse :: Applicative f => (a -> f b) -> Array a -> f (Array b)
477-
traverse f = \ ary -> fromList (length ary) `fmap`
478-
Traversable.traverse f (toList ary)
479+
newtype STA a = STA {_runSTA :: forall s. MutableArray# s a -> ST s (Array a)}
480+
481+
runSTA :: Int -> STA a -> Array a
482+
runSTA !n (STA m) = runST $ new_ n >>= \ (MArray ar) -> m ar
483+
484+
traverse :: forall f a b. Applicative f => (a -> f b) -> Array a -> f (Array b)
485+
traverse f = \ !ary -> runSTA (length ary) <$> foldr go stop ary 0#
486+
where
487+
go :: a -> (Int# -> f (STA b)) -> Int# -> f (STA b)
488+
go a r i = liftA2 (\b (STA m) -> STA $ \mry# -> write (MArray mry#) (I# i) b >> m mry#) (f a) (r (i +# 1#))
489+
stop :: Int# -> f (STA b)
490+
stop _i = pure (STA (\mry# -> unsafeFreeze (MArray mry#)))
479491
{-# INLINE [1] traverse #-}
480492

481-
-- Traversing in ST, we don't need to make a list; we
493+
traverse' :: forall f a b. Applicative f => (a -> f b) -> Array a -> f (Array b)
494+
traverse' f = \ !ary -> runSTA (length ary) <$> foldr go stop ary 0#
495+
where
496+
go :: a -> (Int# -> f (STA b)) -> Int# -> f (STA b)
497+
go a r i = liftA2 (\ !b (STA m) -> STA $ \mry# -> write (MArray mry#) (I# i) b >> m mry#) (f a) (r (i +# 1#))
498+
stop :: Int# -> f (STA b)
499+
stop _i = pure (STA (\mry# -> unsafeFreeze (MArray mry#)))
500+
{-# INLINE [1] traverse' #-}
501+
502+
-- Traversing in ST, we don't need to get fancy; we
482503
-- can just do it directly.
483504
traverseST :: (a -> ST s b) -> Array a -> ST s (Array b)
484505
traverseST f = \ ary0 ->
485506
let
486507
!len = length ary0
487-
go k mary
508+
go k !mary
488509
| k == len = return mary
489510
| otherwise = do
490511
x <- indexM ary0 k
@@ -494,8 +515,59 @@ traverseST f = \ ary0 ->
494515
in new_ len >>= (go 0 >=> unsafeFreeze)
495516
{-# INLINE traverseST #-}
496517

518+
traverseIO :: (a -> IO b) -> Array a -> IO (Array b)
519+
traverseIO f = \ ary0 ->
520+
let
521+
!len = length ary0
522+
go k !mary
523+
| k == len = return mary
524+
| otherwise = do
525+
x <- stToIO $ indexM ary0 k
526+
y <- f x
527+
stToIO $ write mary k y
528+
go (k + 1) mary
529+
in stToIO (new_ len) >>= (go 0 >=> stToIO . unsafeFreeze)
530+
{-# INLINE traverseIO #-}
531+
532+
497533
{-# RULES
498534
"traverse/ST" forall f. traverse f = traverseST f
535+
"traverse/IO" forall f. traverse f = traverseIO f
536+
#-}
537+
538+
-- Traversing in ST, we don't need to get fancy; we
539+
-- can just do it directly.
540+
traverseST' :: (a -> ST s b) -> Array a -> ST s (Array b)
541+
traverseST' f = \ ary0 ->
542+
let
543+
!len = length ary0
544+
go k !mary
545+
| k == len = return mary
546+
| otherwise = do
547+
x <- indexM ary0 k
548+
!y <- f x
549+
write mary k y
550+
go (k + 1) mary
551+
in new_ len >>= (go 0 >=> unsafeFreeze)
552+
{-# INLINE traverseST' #-}
553+
554+
traverseIO' :: (a -> IO b) -> Array a -> IO (Array b)
555+
traverseIO' f = \ ary0 ->
556+
let
557+
!len = length ary0
558+
go k !mary
559+
| k == len = return mary
560+
| otherwise = do
561+
x <- stToIO $ indexM ary0 k
562+
!y <- f x
563+
stToIO $ write mary k y
564+
go (k + 1) mary
565+
in stToIO (new_ len) >>= (go 0 >=> stToIO . unsafeFreeze)
566+
{-# INLINE traverseIO' #-}
567+
568+
{-# RULES
569+
"traverse'/ST" forall f. traverse' f = traverseST' f
570+
"traverse'/IO" forall f. traverse' f = traverseIO' f
499571
#-}
500572

501573
filter :: (a -> Bool) -> Array a -> Array a

Data/HashMap/Base.hs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ infixl 9 !
634634

635635
-- | Create a 'Collision' value with two 'Leaf' values.
636636
collision :: Hash -> Leaf k v -> Leaf k v -> HashMap k v
637-
collision h e1 e2 =
637+
collision h !e1 !e2 =
638638
let v = A.run $ do mary <- A.new 2 e1
639639
A.write mary 1 e2
640640
return mary
@@ -1432,18 +1432,25 @@ map f = mapWithKey (const f)
14321432
-- TODO: We should be able to use mutation to create the new
14331433
-- 'HashMap'.
14341434

1435-
-- | /O(n)/ Transform this map by accumulating an Applicative result
1436-
-- from every value.
1437-
traverseWithKey :: Applicative f => (k -> v1 -> f v2) -> HashMap k v1
1438-
-> f (HashMap k v2)
1435+
-- | /O(n)/ Perform an 'Applicative' action for each key-value pair
1436+
-- in a 'HashMap' and produce a 'HashMap' of all the results.
1437+
--
1438+
-- Note: the order in which the actions occur is unspecified. In particular,
1439+
-- when the map contains hash collisions, the order in which the actions
1440+
-- associated with the keys involved will depend in an unspecified way on
1441+
-- their insertion order.
1442+
traverseWithKey
1443+
:: Applicative f
1444+
=> (k -> v1 -> f v2)
1445+
-> HashMap k v1 -> f (HashMap k v2)
14391446
traverseWithKey f = go
14401447
where
14411448
go Empty = pure Empty
14421449
go (Leaf h (L k v)) = Leaf h . L k <$> f k v
14431450
go (BitmapIndexed b ary) = BitmapIndexed b <$> A.traverse go ary
14441451
go (Full ary) = Full <$> A.traverse go ary
14451452
go (Collision h ary) =
1446-
Collision h <$> A.traverse (\ (L k v) -> L k <$> f k v) ary
1453+
Collision h <$> A.traverse' (\ (L k v) -> L k <$> f k v) ary
14471454
{-# INLINE traverseWithKey #-}
14481455

14491456
------------------------------------------------------------------------

Data/HashMap/Strict/Base.hs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ module Data.HashMap.Strict.Base
9393
import Data.Bits ((.&.), (.|.))
9494

9595
#if !MIN_VERSION_base(4,8,0)
96-
import Data.Functor((<$>))
96+
import Control.Applicative (Applicative (..), (<$>))
9797
#endif
9898
import qualified Data.List as L
9999
import Data.Hashable (Hashable)
@@ -104,7 +104,8 @@ import qualified Data.HashMap.Base as HM
104104
import Data.HashMap.Base hiding (
105105
alter, alterF, adjust, fromList, fromListWith, insert, insertWith,
106106
differenceWith, intersectionWith, intersectionWithKey, map, mapWithKey,
107-
mapMaybe, mapMaybeWithKey, singleton, update, unionWith, unionWithKey)
107+
mapMaybe, mapMaybeWithKey, singleton, update, unionWith, unionWithKey,
108+
traverseWithKey)
108109
import Data.HashMap.Unsafe (runST)
109110
#if MIN_VERSION_base(4,8,0)
110111
import Data.Functor.Identity
@@ -522,8 +523,31 @@ mapMaybe :: (v1 -> Maybe v2) -> HashMap k v1 -> HashMap k v2
522523
mapMaybe f = mapMaybeWithKey (const f)
523524
{-# INLINE mapMaybe #-}
524525

525-
526-
-- TODO: Should we add a strict traverseWithKey?
526+
-- | /O(n)/ Perform an 'Applicative' action for each key-value pair
527+
-- in a 'HashMap' and produce a 'HashMap' of all the results. Each 'HashMap'
528+
-- will be strict in all its values.
529+
--
530+
-- @
531+
-- traverseWithKey f = fmap ('map' id) . "Data.HashMap.Lazy".'Data.HashMap.Lazy.traverseWithKey' f
532+
-- @
533+
--
534+
-- Note: the order in which the actions occur is unspecified. In particular,
535+
-- when the map contains hash collisions, the order in which the actions
536+
-- associated with the keys involved will depend in an unspecified way on
537+
-- their insertion order.
538+
traverseWithKey
539+
:: Applicative f
540+
=> (k -> v1 -> f v2)
541+
-> HashMap k v1 -> f (HashMap k v2)
542+
traverseWithKey f = go
543+
where
544+
go Empty = pure Empty
545+
go (Leaf h (L k v)) = leaf h k <$> f k v
546+
go (BitmapIndexed b ary) = BitmapIndexed b <$> A.traverse' go ary
547+
go (Full ary) = Full <$> A.traverse' go ary
548+
go (Collision h ary) =
549+
Collision h <$> A.traverse' (\ (L k v) -> (L k $!) <$> f k v) ary
550+
{-# INLINE traverseWithKey #-}
527551

528552
------------------------------------------------------------------------
529553
-- * Difference and intersection
@@ -643,5 +667,5 @@ updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
643667
-- inserted into the constructor.
644668

645669
leaf :: Hash -> k -> v -> HashMap k v
646-
leaf h k !v = Leaf h (L k v)
670+
leaf h k = \ !v -> Leaf h (L k v)
647671
{-# INLINE leaf #-}

tests/HashMapProperties.hs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ pUnions xss = M.toAscList (M.unions (map M.fromList xss)) ==
246246
pMap :: [(Key, Int)] -> Bool
247247
pMap = M.map (+ 1) `eq_` HM.map (+ 1)
248248

249+
pTraverse :: [(Key, Int)] -> Bool
250+
pTraverse xs =
251+
L.sort (fmap (L.sort . M.toList) (M.traverseWithKey (\_ v -> [v + 1, v + 2]) (M.fromList (take 10 xs))))
252+
== L.sort (fmap (L.sort . HM.toList) (HM.traverseWithKey (\_ v -> [v + 1, v + 2]) (HM.fromList (take 10 xs))))
253+
249254
------------------------------------------------------------------------
250255
-- ** Difference and intersection
251256

@@ -382,6 +387,7 @@ tests =
382387
, testProperty "unions" pUnions
383388
-- Transformations
384389
, testProperty "map" pMap
390+
, testProperty "traverse" pTraverse
385391
-- Folds
386392
, testGroup "folds"
387393
[ testProperty "foldr" pFoldr

0 commit comments

Comments
 (0)