Skip to content

Commit 83a36ee

Browse files
committed
Implement unionWithKey for HashMap
1 parent acb67e3 commit 83a36ee

File tree

4 files changed

+68
-22
lines changed

4 files changed

+68
-22
lines changed

Data/HashMap/Base.hs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ module Data.HashMap.Base
3535
-- ** Union
3636
, union
3737
, unionWith
38+
, unionWithKey
3839
, unions
3940

4041
-- * Transformations
@@ -86,6 +87,7 @@ module Data.HashMap.Base
8687
, update16M
8788
, update16With'
8889
, updateOrConcatWith
90+
, updateOrConcatWithKey
8991
, filterMapAux
9092
) where
9193

@@ -656,25 +658,33 @@ union = unionWith const
656658
-- result.
657659
unionWith :: (Eq k, Hashable k) => (v -> v -> v) -> HashMap k v -> HashMap k v
658660
-> HashMap k v
659-
unionWith f = go 0
661+
unionWith f = unionWithKey (const f)
662+
{-# INLINE unionWith #-}
663+
664+
-- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
665+
-- the provided function (first argument) will be used to compute the
666+
-- result.
667+
unionWithKey :: (Eq k, Hashable k) => (k -> v -> v -> v) -> HashMap k v -> HashMap k v
668+
-> HashMap k v
669+
unionWithKey f = go 0
660670
where
661671
-- empty vs. anything
662672
go !_ t1 Empty = t1
663673
go _ Empty t2 = t2
664674
-- leaf vs. leaf
665675
go s t1@(Leaf h1 l1@(L k1 v1)) t2@(Leaf h2 l2@(L k2 v2))
666676
| h1 == h2 = if k1 == k2
667-
then Leaf h1 (L k1 (f v1 v2))
677+
then Leaf h1 (L k1 (f k1 v1 v2))
668678
else collision h1 l1 l2
669679
| otherwise = goDifferentHash s h1 h2 t1 t2
670680
go s t1@(Leaf h1 (L k1 v1)) t2@(Collision h2 ls2)
671-
| h1 == h2 = Collision h1 (updateOrSnocWith f k1 v1 ls2)
681+
| h1 == h2 = Collision h1 (updateOrSnocWithKey f k1 v1 ls2)
672682
| otherwise = goDifferentHash s h1 h2 t1 t2
673683
go s t1@(Collision h1 ls1) t2@(Leaf h2 (L k2 v2))
674-
| h1 == h2 = Collision h1 (updateOrSnocWith (flip f) k2 v2 ls1)
684+
| h1 == h2 = Collision h1 (updateOrSnocWithKey (flip . f) k2 v2 ls1)
675685
| otherwise = goDifferentHash s h1 h2 t1 t2
676686
go s t1@(Collision h1 ls1) t2@(Collision h2 ls2)
677-
| h1 == h2 = Collision h1 (updateOrConcatWith f ls1 ls2)
687+
| h1 == h2 = Collision h1 (updateOrConcatWithKey f ls1 ls2)
678688
| otherwise = goDifferentHash s h1 h2 t1 t2
679689
-- branch vs. branch
680690
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
@@ -736,7 +746,7 @@ unionWith f = go 0
736746
where
737747
m1 = mask h1 s
738748
m2 = mask h2 s
739-
{-# INLINE unionWith #-}
749+
{-# INLINE unionWithKey #-}
740750

741751
-- | Strict in the result of @f@.
742752
unionArrayBy :: (a -> a -> a) -> Bitmap -> Bitmap -> A.Array a -> A.Array a
@@ -1099,7 +1109,12 @@ updateWith f k0 ary0 = go k0 ary0 0 (A.length ary0)
10991109

11001110
updateOrSnocWith :: Eq k => (v -> v -> v) -> k -> v -> A.Array (Leaf k v)
11011111
-> A.Array (Leaf k v)
1102-
updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
1112+
updateOrSnocWith f = updateOrSnocWithKey (const f)
1113+
{-# INLINABLE updateOrSnocWith #-}
1114+
1115+
updateOrSnocWithKey :: Eq k => (k -> v -> v -> v) -> k -> v -> A.Array (Leaf k v)
1116+
-> A.Array (Leaf k v)
1117+
updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
11031118
where
11041119
go !k v !ary !i !n
11051120
| i >= n = A.run $ do
@@ -1109,12 +1124,16 @@ updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
11091124
A.write mary n (L k v)
11101125
return mary
11111126
| otherwise = case A.index ary i of
1112-
(L kx y) | k == kx -> A.update ary i (L k (f v y))
1127+
(L kx y) | k == kx -> A.update ary i (L k (f k v y))
11131128
| otherwise -> go k v ary (i+1) n
1114-
{-# INLINABLE updateOrSnocWith #-}
1129+
{-# INLINABLE updateOrSnocWithKey #-}
11151130

11161131
updateOrConcatWith :: Eq k => (v -> v -> v) -> A.Array (Leaf k v) -> A.Array (Leaf k v) -> A.Array (Leaf k v)
1117-
updateOrConcatWith f ary1 ary2 = A.run $ do
1132+
updateOrConcatWith f = updateOrConcatWithKey (const f)
1133+
{-# INLINABLE updateOrConcatWith #-}
1134+
1135+
updateOrConcatWithKey :: Eq k => (k -> v -> v -> v) -> A.Array (Leaf k v) -> A.Array (Leaf k v) -> A.Array (Leaf k v)
1136+
updateOrConcatWithKey f ary1 ary2 = A.run $ do
11181137
-- first: look up the position of each element of ary2 in ary1
11191138
let indices = A.map (\(L k _) -> indexOf k ary1) ary2
11201139
-- that tells us how large the overlap is:
@@ -1132,14 +1151,14 @@ updateOrConcatWith f ary1 ary2 = A.run $ do
11321151
Just i1 -> do -- key occurs in both arrays, store combination in position i1
11331152
L k v1 <- A.indexM ary1 i1
11341153
L _ v2 <- A.indexM ary2 i2
1135-
A.write mary i1 (L k (f v1 v2))
1154+
A.write mary i1 (L k (f k v1 v2))
11361155
go iEnd (i2+1)
11371156
Nothing -> do -- key is only in ary2, append to end
11381157
A.write mary iEnd =<< A.indexM ary2 i2
11391158
go (iEnd+1) (i2+1)
11401159
go n1 0
11411160
return mary
1142-
{-# INLINABLE updateOrConcatWith #-}
1161+
{-# INLINABLE updateOrConcatWithKey #-}
11431162

11441163
------------------------------------------------------------------------
11451164
-- Manually unrolled loops

Data/HashMap/Lazy.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ module Data.HashMap.Lazy
5454
-- ** Union
5555
, union
5656
, unionWith
57+
, unionWithKey
5758
, unions
5859

5960
-- * Transformations

Data/HashMap/Strict.hs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ module Data.HashMap.Strict
5454
-- ** Union
5555
, union
5656
, unionWith
57+
, unionWithKey
5758
, unions
5859

5960
-- * Transformations
@@ -99,7 +100,7 @@ import qualified Data.HashMap.Base as HM
99100
import Data.HashMap.Base hiding (
100101
alter, adjust, fromList, fromListWith, insert, insertWith, intersectionWith,
101102
intersectionWithKey, map, mapWithKey, mapMaybe, mapMaybeWithKey, singleton,
102-
update, unionWith)
103+
update, unionWith, unionWithKey)
103104
import Data.HashMap.Unsafe (runST)
104105

105106
-- $strictness
@@ -257,25 +258,32 @@ alter f k m =
257258
-- the provided function (first argument) will be used to compute the result.
258259
unionWith :: (Eq k, Hashable k) => (v -> v -> v) -> HashMap k v -> HashMap k v
259260
-> HashMap k v
260-
unionWith f = go 0
261+
unionWith f = unionWithKey (const f)
262+
{-# INLINE unionWith #-}
263+
264+
-- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
265+
-- the provided function (first argument) will be used to compute the result.
266+
unionWithKey :: (Eq k, Hashable k) => (k -> v -> v -> v) -> HashMap k v -> HashMap k v
267+
-> HashMap k v
268+
unionWithKey f = go 0
261269
where
262270
-- empty vs. anything
263271
go !_ t1 Empty = t1
264272
go _ Empty t2 = t2
265273
-- leaf vs. leaf
266274
go s t1@(Leaf h1 l1@(L k1 v1)) t2@(Leaf h2 l2@(L k2 v2))
267275
| h1 == h2 = if k1 == k2
268-
then leaf h1 k1 (f v1 v2)
276+
then leaf h1 k1 (f k1 v1 v2)
269277
else collision h1 l1 l2
270278
| otherwise = goDifferentHash s h1 h2 t1 t2
271279
go s t1@(Leaf h1 (L k1 v1)) t2@(Collision h2 ls2)
272-
| h1 == h2 = Collision h1 (updateOrSnocWith f k1 v1 ls2)
280+
| h1 == h2 = Collision h1 (updateOrSnocWithKey f k1 v1 ls2)
273281
| otherwise = goDifferentHash s h1 h2 t1 t2
274282
go s t1@(Collision h1 ls1) t2@(Leaf h2 (L k2 v2))
275-
| h1 == h2 = Collision h1 (updateOrSnocWith (flip f) k2 v2 ls1)
283+
| h1 == h2 = Collision h1 (updateOrSnocWithKey (flip . f) k2 v2 ls1)
276284
| otherwise = goDifferentHash s h1 h2 t1 t2
277285
go s t1@(Collision h1 ls1) t2@(Collision h2 ls2)
278-
| h1 == h2 = Collision h1 (updateOrConcatWith f ls1 ls2)
286+
| h1 == h2 = Collision h1 (updateOrConcatWithKey f ls1 ls2)
279287
| otherwise = goDifferentHash s h1 h2 t1 t2
280288
-- branch vs. branch
281289
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
@@ -337,7 +345,7 @@ unionWith f = go 0
337345
where
338346
m1 = mask h1 s
339347
m2 = mask h2 s
340-
{-# INLINE unionWith #-}
348+
{-# INLINE unionWithKey #-}
341349

342350
------------------------------------------------------------------------
343351
-- * Transformations
@@ -446,7 +454,17 @@ updateWith f k0 ary0 = go k0 ary0 0 (A.length ary0)
446454
-- array.
447455
updateOrSnocWith :: Eq k => (v -> v -> v) -> k -> v -> A.Array (Leaf k v)
448456
-> A.Array (Leaf k v)
449-
updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
457+
updateOrSnocWith f = updateOrSnocWithKey (const f)
458+
{-# INLINABLE updateOrSnocWith #-}
459+
460+
-- | Append the given key and value to the array. If the key is
461+
-- already present, instead update the value of the key by applying
462+
-- the given function to the new and old value (in that order). The
463+
-- value is always evaluated to WHNF before being inserted into the
464+
-- array.
465+
updateOrSnocWithKey :: Eq k => (k -> v -> v -> v) -> k -> v -> A.Array (Leaf k v)
466+
-> A.Array (Leaf k v)
467+
updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
450468
where
451469
go !k v !ary !i !n
452470
| i >= n = A.run $ do
@@ -457,9 +475,9 @@ updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
457475
A.write mary n l
458476
return mary
459477
| otherwise = case A.index ary i of
460-
(L kx y) | k == kx -> let !v' = f v y in A.update ary i (L k v')
478+
(L kx y) | k == kx -> let !v' = f k v y in A.update ary i (L k v')
461479
| otherwise -> go k v ary (i+1) n
462-
{-# INLINABLE updateOrSnocWith #-}
480+
{-# INLINABLE updateOrSnocWithKey #-}
463481

464482
------------------------------------------------------------------------
465483
-- Smart constructors

tests/HashMapProperties.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ pUnionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool
138138
pUnionWith xs ys = M.unionWith (-) (M.fromList xs) `eq_`
139139
HM.unionWith (-) (HM.fromList xs) $ ys
140140

141+
pUnionWithKey :: [(Key, Int)] -> [(Key, Int)] -> Bool
142+
pUnionWithKey xs ys = M.unionWithKey go (M.fromList xs) `eq_`
143+
HM.unionWithKey go (HM.fromList xs) $ ys
144+
where
145+
go :: Key -> Int -> Int -> Int
146+
go (K k) i1 i2 = k - i1 + i2
147+
141148
pUnions :: [[(Key, Int)]] -> Bool
142149
pUnions xss = M.toAscList (M.unions (map M.fromList xss)) ==
143150
toAscList (HM.unions (map HM.fromList xss))
@@ -264,6 +271,7 @@ tests =
264271
-- Combine
265272
, testProperty "union" pUnion
266273
, testProperty "unionWith" pUnionWith
274+
, testProperty "unionWithKey" pUnionWithKey
267275
, testProperty "unions" pUnions
268276
-- Transformations
269277
, testProperty "map" pMap

0 commit comments

Comments
 (0)