Skip to content

Commit 5a304f2

Browse files
authored
Merge pull request #146 from mschristiansen/diffwith
Add `differenceWith` function.
2 parents 2bac253 + 679179c commit 5a304f2

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

Data/HashMap/Base.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ module Data.HashMap.Base
4545

4646
-- * Difference and intersection
4747
, difference
48+
, differenceWith
4849
, intersection
4950
, intersectionWith
5051
, intersectionWithKey
@@ -934,6 +935,18 @@ difference a b = foldlWithKey' go empty a
934935
_ -> m
935936
{-# INLINABLE difference #-}
936937

938+
-- | /O(n*log m)/ Difference with a combining function. When two equal keys are
939+
-- encountered, the combining function is applied to the values of these keys.
940+
-- If it returns 'Nothing', the element is discarded (proper set difference). If
941+
-- it returns (@'Just' y@), the element is updated with a new value @y@.
942+
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
943+
differenceWith f a b = foldlWithKey' go empty a
944+
where
945+
go m k v = case lookup k b of
946+
Nothing -> insert k v m
947+
Just w -> maybe m (\y -> insert k y m) (f v w)
948+
{-# INLINABLE differenceWith #-}
949+
937950
-- | /O(n*log m)/ Intersection of two maps. Return elements of the first
938951
-- map for keys existing in the second.
939952
intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v

Data/HashMap/Lazy.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ module Data.HashMap.Lazy
6464

6565
-- * Difference and intersection
6666
, difference
67+
, differenceWith
6768
, intersection
6869
, intersectionWith
6970
, intersectionWithKey

Data/HashMap/Strict.hs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ module Data.HashMap.Strict
6464

6565
-- * Difference and intersection
6666
, difference
67+
, differenceWith
6768
, intersection
6869
, intersectionWith
6970
, intersectionWithKey
@@ -98,9 +99,9 @@ import Prelude hiding (map)
9899
import qualified Data.HashMap.Array as A
99100
import qualified Data.HashMap.Base as HM
100101
import Data.HashMap.Base hiding (
101-
alter, adjust, fromList, fromListWith, insert, insertWith, intersectionWith,
102-
intersectionWithKey, map, mapWithKey, mapMaybe, mapMaybeWithKey, singleton,
103-
update, unionWith, unionWithKey)
102+
alter, adjust, fromList, fromListWith, insert, insertWith, differenceWith,
103+
intersectionWith, intersectionWithKey, map, mapWithKey, mapMaybe,
104+
mapMaybeWithKey, singleton, update, unionWith, unionWithKey)
104105
import Data.HashMap.Unsafe (runST)
105106

106107
-- $strictness
@@ -394,6 +395,18 @@ mapMaybe f = mapMaybeWithKey (const f)
394395
------------------------------------------------------------------------
395396
-- * Difference and intersection
396397

398+
-- | /O(n*log m)/ Difference with a combining function. When two equal keys are
399+
-- encountered, the combining function is applied to the values of these keys.
400+
-- If it returns 'Nothing', the element is discarded (proper set difference). If
401+
-- it returns (@'Just' y@), the element is updated with a new value @y@.
402+
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
403+
differenceWith f a b = foldlWithKey' go empty a
404+
where
405+
go m k v = case HM.lookup k b of
406+
Nothing -> insert k v m
407+
Just w -> maybe m (\y -> insert k y m) (f v w)
408+
{-# INLINABLE differenceWith #-}
409+
397410
-- | /O(n+m)/ Intersection of two maps. If a key occurs in both maps
398411
-- the provided function is used to combine the values from the two
399412
-- maps.

tests/HashMapProperties.hs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ pDifference :: [(Key, Int)] -> [(Key, Int)] -> Bool
163163
pDifference xs ys = M.difference (M.fromList xs) `eq_`
164164
HM.difference (HM.fromList xs) $ ys
165165

166+
pDifferenceWith :: [(Key, Int)] -> [(Key, Int)] -> Bool
167+
pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_`
168+
HM.differenceWith f (HM.fromList xs) $ ys
169+
where
170+
f x y = if x == 0 then Nothing else Just (x - y)
171+
166172
pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool
167173
pIntersection xs ys = M.intersection (M.fromList xs) `eq_`
168174
HM.intersection (HM.fromList xs) $ ys
@@ -284,6 +290,7 @@ tests =
284290
]
285291
, testGroup "difference and intersection"
286292
[ testProperty "difference" pDifference
293+
, testProperty "differenceWith" pDifferenceWith
287294
, testProperty "intersection" pIntersection
288295
, testProperty "intersectionWith" pIntersectionWith
289296
, testProperty "intersectionWithKey" pIntersectionWithKey

0 commit comments

Comments
 (0)