Skip to content

Commit 2357b66

Browse files
committed
Add size-aware 'union' functions
1 parent da4f3c8 commit 2357b66

File tree

2 files changed

+218
-1
lines changed

2 files changed

+218
-1
lines changed

Data/HashMap/Array.hs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ module Data.HashMap.Array
2424
, indexM
2525
, update
2626
, updateWith'
27+
, updateWithInternal'
2728
, unsafeUpdateM
2829
, insert
2930
, insertM
@@ -32,6 +33,7 @@ module Data.HashMap.Array
3233
, unsafeFreeze
3334
, unsafeThaw
3435
, run
36+
, runInternal
3537
, run2
3638
, copy
3739
, copyM
@@ -232,6 +234,13 @@ run :: (forall s . ST s (MArray s e)) -> Array e
232234
run act = runST $ act >>= unsafeFreeze
233235
{-# INLINE run #-}
234236

237+
-- | Run an 'ST' action that produces an 'Int' together with a mutable
-- array, freeze that array with 'unsafeFreeze', and return the 'Int'
-- paired with the frozen array.
runInternal :: (forall s . ST s (Int, MArray s e)) -> (Int, Array e)
runInternal act = runST $
    act >>= \(n, marr) ->
    unsafeFreeze marr >>= \arr ->
    return (n, arr)
{-# INLINE runInternal #-}
243+
235244
run2 :: (forall s. ST s (MArray s e, a)) -> (Array e, a)
236245
run2 k = runST (do
237246
(marr,b) <- k
@@ -297,6 +306,15 @@ updateWith' :: Array e -> Int -> (e -> e) -> Array e
297306
updateWith' ary idx f = update ary idx $! f (index ary idx)
298307
{-# INLINE updateWith' #-}
299308

309+
-- | /O(n)/ Update the element at the given position in this array, by
-- applying a function to it.  The function returns an 'Int' alongside the
-- replacement element; that 'Int' is passed through unchanged as the first
-- component of the result.
--
-- As soon as the result pair is demanded, both the 'Int' and the new
-- element are evaluated to WHNF, i.e. before the element is handed to
-- 'update'.
updateWithInternal' :: Array e -> Int -> (e -> (Int, e)) -> (Int, Array e)
updateWithInternal' ary idx f =
    -- A @case@ is used instead of @let (!sz, !e) = ...@: bang patterns
    -- nested inside a @let@ pattern do not make the binding strict, so the
    -- element could otherwise be stored in the array as an unevaluated thunk.
    case f (index ary idx) of
        (!sz, !e) -> (sz, update ary idx e)
{-# INLINE updateWithInternal' #-}
317+
300318
-- | /O(1)/ Update the element at the given position in this array,
301319
-- without copying.
302320
unsafeUpdateM :: Array e -> Int -> e -> ST s ()

Data/HashMap/Base.hs

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ type role HashMap nominal representational
160160

161161
-- | WIP. This will become the user-facing 'HashMap' after this PR is
162162
-- finalized.
163-
data HashMapW = HashMapW {-# UNPACK #-} !Int !HashMap
163+
data HashMapW k v = HashMapW {-# UNPACK #-} !Int !(HashMap k v)
164164

165165
instance (NFData k, NFData v) => NFData (HashMap k v) where
166166
rnf Empty = ()
@@ -1033,6 +1033,14 @@ union :: (Eq k, Hashable k) => HashMap k v -> HashMap k v -> HashMap k v
10331033
union = unionWith const
10341034
{-# INLINABLE union #-}
10351035

1036+
-- | /O(n+m)/ Left-biased union of two maps: when a key is present in both,
-- the value from the first map is the one kept in the result.
--
-- The result is a pair of the increase in the first hashmap's size and the
-- union of the two maps.
unionInternal :: (Eq k, Hashable k) => HashMap k v -> HashMapW k v -> (Int, HashMap k v)
unionInternal = unionWithInternal (\leftValue _ -> leftValue)
{-# INLINABLE unionInternal #-}
1043+
10361044
-- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
10371045
-- the provided function (first argument) will be used to compute the
10381046
-- result.
@@ -1041,6 +1049,20 @@ unionWith :: (Eq k, Hashable k) => (v -> v -> v) -> HashMap k v -> HashMap k v
10411049
unionWith f = unionWithKey (const f)
10421050
{-# INLINE unionWith #-}
10431051

1052+
-- | /O(n+m)/ The union of two maps, resolving keys that occur in both maps
-- with the supplied combining function (first argument).  The key itself is
-- not made available to the combining function.
--
-- The result is a pair of the increase in the first hashmap's size and the
-- union of the two maps.
unionWithInternal
    :: (Eq k, Hashable k)
    => (v -> v -> v)
    -> HashMap k v
    -> HashMapW k v
    -> (Int, HashMap k v)
unionWithInternal f = unionWithKeyInternal (\_ -> f)
{-# INLINE unionWithInternal #-}
1065+
10441066
-- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
10451067
-- the provided function (first argument) will be used to compute the
10461068
-- result.
@@ -1128,6 +1150,140 @@ unionWithKey f = go 0
11281150
m2 = mask h2 s
11291151
{-# INLINE unionWithKey #-}
11301152

1153+
-- | /O(n+m)/ The union of two maps. If a key occurs in both maps,
-- the provided function (first argument) will be used to compute the
-- result.
--
-- The second map is supplied as a 'HashMapW', i.e. together with an 'Int'
-- count.  That count is threaded through the traversal and decremented once
-- for every key that is found in both maps.  The result is a tuple whose
-- first component is the final count (how many elements were added to the
-- first hashmap) and whose second component is the union hashmap itself.
unionWithKeyInternal
    :: forall k v . (Eq k, Hashable k)
    => (k -> v -> v -> v)
    -> HashMap k v
    -> HashMapW k v
    -> (Int, HashMap k v)
unionWithKeyInternal f h1 (HashMapW size h2) = go 0 size h1 h2
  where
    -- @go s sz t1 t2@: @s@ is the current bit offset into the hash and
    -- @sz@ is the running count described above.
    go :: Int -> Int -> HashMap k v -> HashMap k v -> (Int, HashMap k v)
    -- empty vs. anything
    go !_ !sz t1 Empty = (sz, t1)
    go _ !sz Empty t2 = (sz, t2)
    -- leaf vs. leaf
    go s !sz t1@(Leaf h1 l1@(L k1 v1)) t2@(Leaf h2 l2@(L k2 v2))
        | h1 == h2 =
            if k1 == k2
                -- Same key on both sides: one element fewer is added.
                then (sz - 1, Leaf h1 (L k1 (f k1 v1 v2)))
                else (sz, collision h1 l1 l2)
        | otherwise = goDifferentHash sz s h1 h2 t1 t2
    go s !sz t1@(Leaf h1 (L k1 v1)) t2@(Collision h2 ls2)
        | h1 == h2 =
            -- @end - start@ is how much the bucket grew.  The count drops by
            -- one exactly when the bucket did not grow (assumes
            -- 'updateOrSnocWithKey' either replaces in place or appends a
            -- single entry -- TODO confirm).
            let !start = A.length ls2
                !newV  = updateOrSnocWithKey f k1 v1 ls2
                !end   = A.length newV
            in (sz + end - start - 1, Collision h1 newV)
        | otherwise = goDifferentHash sz s h1 h2 t1 t2
    go s !sz t1@(Collision h1 ls1) t2@(Leaf h2 (L k2 v2))
        | h1 == h2 =
            -- The combining function has its value arguments flipped because
            -- the single leaf comes from the second map here.
            let !start = A.length ls1
                !newV  = updateOrSnocWithKey (flip . f) k2 v2 ls1
                !end   = A.length newV
            in (sz + end - start - 1, Collision h1 newV)
        | otherwise = goDifferentHash sz s h1 h2 t1 t2
    go s !sz t1@(Collision h1 ls1) t2@(Collision h2 ls2)
        | h1 == h2 =
            -- @end - start - A.length ls2@ is minus the number of entries of
            -- @ls2@ that did not enlarge the merged bucket.
            let !start = A.length ls1
                !newV  = updateOrConcatWithKey f ls1 ls2
                !end   = A.length newV
            in (sz + (end - start - A.length ls2), Collision h1 newV)
        | otherwise = goDifferentHash sz s h1 h2 t1 t2
    -- branch vs. branch
    go s !sz (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
        let b' = b1 .|. b2
            (dsz, ary') =
                unionArrayByInternal sz
                                     (go (s+bitsPerSubkey))
                                     b1
                                     b2
                                     ary1
                                     ary2
        in (dsz, bitmapIndexedOrFull b' ary')
    go s !sz (BitmapIndexed b1 ary1) (Full ary2) =
        let (dsz, ary') =
                unionArrayByInternal sz
                                     (go (s+bitsPerSubkey))
                                     b1
                                     fullNodeMask
                                     ary1
                                     ary2
        in (dsz, Full ary')
    go s !sz (Full ary1) (BitmapIndexed b2 ary2) =
        let (dsz, ary') =
                unionArrayByInternal sz
                                     (go (s+bitsPerSubkey))
                                     fullNodeMask
                                     b2
                                     ary1
                                     ary2
        in (dsz, Full ary')
    go s !sz (Full ary1) (Full ary2) =
        let (dsz, ary') =
                unionArrayByInternal sz
                                     (go (s+bitsPerSubkey))
                                     fullNodeMask
                                     fullNodeMask
                                     ary1
                                     ary2
        in (dsz, Full ary')
    -- leaf vs. branch
    go s !sz (BitmapIndexed b1 ary1) t2
        | b1 .&. m2 == 0 =
            -- No child at this sub-key yet: insert @t2@ as a new child.
            let ary' = A.insert ary1 i t2
                b'   = b1 .|. m2
            in (sz, bitmapIndexedOrFull b' ary')
        | otherwise =
            -- A child already exists: merge @t2@ into it one level deeper.
            let (dsz, ary') = A.updateWithInternal' ary1 i $ \st1 ->
                    go (s+bitsPerSubkey) sz st1 t2
            in (dsz, BitmapIndexed b1 ary')
      where
        h2 = leafHashCode t2
        m2 = mask h2 s
        i = sparseIndex b1 m2
    go s !sz t1 (BitmapIndexed b2 ary2)
        | b2 .&. m1 == 0 =
            let ary' = A.insert ary2 i $! t1
                b'   = b2 .|. m1
            in (sz, bitmapIndexedOrFull b' ary')
        | otherwise =
            let (dsz, ary') = A.updateWithInternal' ary2 i $ \st2 ->
                    go (s+bitsPerSubkey) sz t1 st2
            in (dsz, BitmapIndexed b2 ary')
      where
        h1 = leafHashCode t1
        m1 = mask h1 s
        i = sparseIndex b2 m1
    go s !sz (Full ary1) t2 =
        let h2 = leafHashCode t2
            i = index h2 s
            (dsz, ary') =
                update16WithInternal' ary1 i $ \st1 ->
                    go (s+bitsPerSubkey) sz st1 t2
        in (dsz, Full ary')
    go s !sz t1 (Full ary2) =
        let h1 = leafHashCode t1
            i = index h1 s
            (dsz, ary') =
                update16WithInternal' ary2 i $ \st2 ->
                    go (s+bitsPerSubkey) sz t1 st2
        in (dsz, Full ary')

    -- Hash code of a 'Leaf' or 'Collision' node; any other node is an error.
    leafHashCode (Leaf h _) = h
    leafHashCode (Collision h _) = h
    leafHashCode _ = error "leafHashCode"

    -- Combine two subtrees whose hash codes differ.  Note the argument
    -- order: the running count comes first here, the shift second.
    goDifferentHash sz s h1 h2 t1 t2
        | m1 == m2 =
            -- Both hashes share the sub-key at this level, so descend one
            -- level.  'go' expects the shift first and the count second.
            -- The result is forced with a @case@ because bang patterns
            -- nested in a @let@ pattern would leave the binding lazy.
            case go (s+bitsPerSubkey) sz t1 t2 of
                (!dsz, !hm) -> (dsz, BitmapIndexed m1 (A.singleton hm))
        | m1 <  m2  = (sz, BitmapIndexed (m1 .|. m2) (A.pair t1 t2))
        | otherwise = (sz, BitmapIndexed (m1 .|. m2) (A.pair t2 t1))
      where
        m1 = mask h1 s
        m2 = mask h2 s
{-# INLINE unionWithKeyInternal #-}
1286+
11311287
-- | Strict in the result of @f@.
11321288
unionArrayBy :: (a -> a -> a) -> Bitmap -> Bitmap -> A.Array a -> A.Array a
11331289
-> A.Array a
@@ -1156,6 +1312,42 @@ unionArrayBy f b1 b2 ary1 ary2 = A.run $ do
11561312
-- where we copy one array, and then update.
11571313
{-# INLINE unionArrayBy #-}
11581314

1315+
-- | Strict in the result of @f@.
--
-- Merges two child arrays described by the bitmaps @b1@ and @b2@ into one
-- array laid out according to @b1 .|. b2@.  Where only one bitmap has a bit
-- set, the corresponding element is copied over.  Where both have it set,
-- @f@ is applied to the running count and the two elements, yielding the
-- updated count and the merged element.
--
-- Returns the final count together with the merged array.
unionArrayByInternal
    :: Int
    -> (Int -> a -> a -> (Int, a))
    -> Bitmap
    -> Bitmap
    -> A.Array a
    -> A.Array a
    -> (Int, A.Array a)
unionArrayByInternal size f b1 b2 ary1 ary2 = A.runInternal $ do
    let b' = b1 .|. b2
    mary <- A.new_ (popCount b')
    -- iterate over nonzero bits of b1 .|. b2
    -- it would be nice if we could shift m by more than 1 each time
    let ba = b1 .&. b2
        -- go :: forall s . Int -> Int -> Int -> Int -> Bitmap -> ST s Int
        -- @sz@ is the running count, @i@ the write index into @mary@,
        -- @i1@/@i2@ the read indices into @ary1@/@ary2@, @m@ the bit
        -- currently being examined.
        go !sz !i !i1 !i2 !m
            | m > b'        = return sz
            | b' .&. m == 0 = go sz i i1 i2 (m `unsafeShiftL` 1)
            | ba .&. m /= 0 =
                -- Force the pair and both components before writing, so
                -- the array never holds an unevaluated thunk.  A
                -- @let (!dsz, !hm) = ...@ would be a lazy binding.
                case f sz (A.index ary1 i1) (A.index ary2 i2) of
                    (!dsz, !hm) -> do
                        A.write mary i hm
                        go dsz (i+1) (i1+1) (i2+1) (m `unsafeShiftL` 1)
            | b1 .&. m /= 0 = do
                A.write mary i =<< A.indexM ary1 i1
                go sz (i+1) (i1+1) (i2  ) (m `unsafeShiftL` 1)
            | otherwise     = do
                A.write mary i =<< A.indexM ary2 i2
                go sz (i+1) (i1  ) (i2+1) (m `unsafeShiftL` 1)
    -- @b' .&. negate b'@ isolates the lowest set bit of @b'@.
    d <- go size 0 0 0 (b' .&. negate b') -- XXX: b' must be non-zero
    return (d, mary)
    -- TODO: For the case where b1 .&. b2 == b1, i.e. when one is a
    -- subset of the other, we could use a slightly simpler algorithm,
    -- where we copy one array, and then update.
{-# INLINE unionArrayByInternal #-}
1350+
11591351
-- TODO: Figure out the time complexity of 'unions'.
11601352

11611353
-- | Construct a set containing all elements from a list of sets.
@@ -1679,6 +1871,13 @@ update16With' :: A.Array e -> Int -> (e -> e) -> A.Array e
16791871
update16With' ary idx f = update16 ary idx $! f (A.index ary idx)
16801872
{-# INLINE update16With' #-}
16811873

1874+
-- | /O(n)/ Update the element at the given position in this array, by
-- applying a function to it.  The function returns an 'Int' alongside the
-- replacement element; that 'Int' is returned unchanged as the first
-- component of the result.
--
-- Strict in the result of @f@: once the result pair is demanded, the 'Int'
-- and the new element are both evaluated to WHNF before the element is
-- passed to 'update16'.  The element read from the array is also forced
-- before @f@ is applied to it.
update16WithInternal' :: A.Array e -> Int -> (e -> (Int, e)) -> (Int, A.Array e)
update16WithInternal' ary idx f =
    -- @case@ rather than a lazy @let@ pattern, so the new element cannot
    -- end up in the array as an unevaluated thunk.
    case f $! A.index ary idx of
        (!s, !x) -> (s, update16 ary idx x)
{-# INLINE update16WithInternal' #-}
1880+
16821881
-- | Unsafely clone an array of 16 elements. The length of the input
16831882
-- array is not checked.
16841883
clone16 :: A.Array e -> ST s (A.MArray s e)

0 commit comments

Comments
 (0)