@@ -35,6 +35,7 @@ module Data.HashMap.Base
35
35
-- ** Union
36
36
, union
37
37
, unionWith
38
+ , unionWithKey
38
39
, unions
39
40
40
41
-- * Transformations
@@ -86,6 +87,7 @@ module Data.HashMap.Base
86
87
, update16M
87
88
, update16With'
88
89
, updateOrConcatWith
90
+ , updateOrConcatWithKey
89
91
, filterMapAux
90
92
) where
91
93
@@ -656,25 +658,33 @@ union = unionWith const
656
658
-- result.
657
659
unionWith :: (Eq k , Hashable k ) => (v -> v -> v ) -> HashMap k v -> HashMap k v
658
660
-> HashMap k v
659
- unionWith f = go 0
661
+ unionWith f = unionWithKey (const f)
662
+ {-# INLINE unionWith #-}
663
+
664
+ -- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
665
+ -- the provided function (first argument) will be used to compute the
666
+ -- result.
667
+ unionWithKey :: (Eq k , Hashable k ) => (k -> v -> v -> v ) -> HashMap k v -> HashMap k v
668
+ -> HashMap k v
669
+ unionWithKey f = go 0
660
670
where
661
671
-- empty vs. anything
662
672
go ! _ t1 Empty = t1
663
673
go _ Empty t2 = t2
664
674
-- leaf vs. leaf
665
675
go s t1@ (Leaf h1 l1@ (L k1 v1)) t2@ (Leaf h2 l2@ (L k2 v2))
666
676
| h1 == h2 = if k1 == k2
667
- then Leaf h1 (L k1 (f v1 v2))
677
+ then Leaf h1 (L k1 (f k1 v1 v2))
668
678
else collision h1 l1 l2
669
679
| otherwise = goDifferentHash s h1 h2 t1 t2
670
680
go s t1@ (Leaf h1 (L k1 v1)) t2@ (Collision h2 ls2)
671
- | h1 == h2 = Collision h1 (updateOrSnocWith f k1 v1 ls2)
681
+ | h1 == h2 = Collision h1 (updateOrSnocWithKey f k1 v1 ls2)
672
682
| otherwise = goDifferentHash s h1 h2 t1 t2
673
683
go s t1@ (Collision h1 ls1) t2@ (Leaf h2 (L k2 v2))
674
- | h1 == h2 = Collision h1 (updateOrSnocWith (flip f) k2 v2 ls1)
684
+ | h1 == h2 = Collision h1 (updateOrSnocWithKey (flip . f) k2 v2 ls1)
675
685
| otherwise = goDifferentHash s h1 h2 t1 t2
676
686
go s t1@ (Collision h1 ls1) t2@ (Collision h2 ls2)
677
- | h1 == h2 = Collision h1 (updateOrConcatWith f ls1 ls2)
687
+ | h1 == h2 = Collision h1 (updateOrConcatWithKey f ls1 ls2)
678
688
| otherwise = goDifferentHash s h1 h2 t1 t2
679
689
-- branch vs. branch
680
690
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
@@ -736,7 +746,7 @@ unionWith f = go 0
736
746
where
737
747
m1 = mask h1 s
738
748
m2 = mask h2 s
739
- {-# INLINE unionWith #-}
749
+ {-# INLINE unionWithKey #-}
740
750
741
751
-- | Strict in the result of @f@.
742
752
unionArrayBy :: (a -> a -> a ) -> Bitmap -> Bitmap -> A. Array a -> A. Array a
@@ -1099,7 +1109,12 @@ updateWith f k0 ary0 = go k0 ary0 0 (A.length ary0)
1099
1109
1100
1110
updateOrSnocWith :: Eq k => (v -> v -> v ) -> k -> v -> A. Array (Leaf k v )
1101
1111
-> A. Array (Leaf k v )
1102
- updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A. length ary0)
1112
+ updateOrSnocWith f = updateOrSnocWithKey (const f)
1113
+ {-# INLINABLE updateOrSnocWith #-}
1114
+
1115
+ updateOrSnocWithKey :: Eq k => (k -> v -> v -> v ) -> k -> v -> A. Array (Leaf k v )
1116
+ -> A. Array (Leaf k v )
1117
+ updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A. length ary0)
1103
1118
where
1104
1119
go ! k v ! ary ! i ! n
1105
1120
| i >= n = A. run $ do
@@ -1109,12 +1124,16 @@ updateOrSnocWith f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0)
1109
1124
A. write mary n (L k v)
1110
1125
return mary
1111
1126
| otherwise = case A. index ary i of
1112
- (L kx y) | k == kx -> A. update ary i (L k (f v y))
1127
+ (L kx y) | k == kx -> A. update ary i (L k (f k v y))
1113
1128
| otherwise -> go k v ary (i+ 1 ) n
1114
- {-# INLINABLE updateOrSnocWith #-}
1129
+ {-# INLINABLE updateOrSnocWithKey #-}
1115
1130
1116
1131
updateOrConcatWith :: Eq k => (v -> v -> v ) -> A. Array (Leaf k v ) -> A. Array (Leaf k v ) -> A. Array (Leaf k v )
1117
- updateOrConcatWith f ary1 ary2 = A. run $ do
1132
+ updateOrConcatWith f = updateOrConcatWithKey (const f)
1133
+ {-# INLINABLE updateOrConcatWith #-}
1134
+
1135
+ updateOrConcatWithKey :: Eq k => (k -> v -> v -> v ) -> A. Array (Leaf k v ) -> A. Array (Leaf k v ) -> A. Array (Leaf k v )
1136
+ updateOrConcatWithKey f ary1 ary2 = A. run $ do
1118
1137
-- first: look up the position of each element of ary2 in ary1
1119
1138
let indices = A. map (\ (L k _) -> indexOf k ary1) ary2
1120
1139
-- that tells us how large the overlap is:
@@ -1132,14 +1151,14 @@ updateOrConcatWith f ary1 ary2 = A.run $ do
1132
1151
Just i1 -> do -- key occurs in both arrays, store combination in position i1
1133
1152
L k v1 <- A. indexM ary1 i1
1134
1153
L _ v2 <- A. indexM ary2 i2
1135
- A. write mary i1 (L k (f v1 v2))
1154
+ A. write mary i1 (L k (f k v1 v2))
1136
1155
go iEnd (i2+ 1 )
1137
1156
Nothing -> do -- key is only in ary2, append to end
1138
1157
A. write mary iEnd =<< A. indexM ary2 i2
1139
1158
go (iEnd+ 1 ) (i2+ 1 )
1140
1159
go n1 0
1141
1160
return mary
1142
- {-# INLINABLE updateOrConcatWith #-}
1161
+ {-# INLINABLE updateOrConcatWithKey #-}
1143
1162
1144
1163
------------------------------------------------------------------------
1145
1164
-- Manually unrolled loops
0 commit comments