Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
21f238b
fast intersection
oberblastmeister Apr 9, 2022
16f1f7f
cleanup
oberblastmeister Apr 9, 2022
bcc13fc
add show back
oberblastmeister Apr 9, 2022
d5262bf
inline
oberblastmeister Apr 9, 2022
a16456b
debug checks
oberblastmeister Apr 9, 2022
f72011c
inline function
oberblastmeister Apr 9, 2022
678a38c
refactor to use snoc
oberblastmeister Apr 9, 2022
ec24215
Try the unboxed result thing
treeowl Apr 9, 2022
767ae6e
Remove redundant internal constraint
treeowl Apr 9, 2022
72510b4
Merge pull request #3 from treeowl/unboxedness
oberblastmeister Apr 9, 2022
fd43ba7
shrink compat
oberblastmeister Apr 9, 2022
3612645
fix import
oberblastmeister Apr 9, 2022
b484042
use clone
oberblastmeister Apr 9, 2022
9e48bc0
oof
oberblastmeister Apr 9, 2022
48119cb
don't shrink to zero
oberblastmeister Apr 9, 2022
d9d295d
Leaf special case
oberblastmeister Apr 9, 2022
88a9c2c
add strict verisons
oberblastmeister Apr 10, 2022
b3cdbd8
Update Data/HashMap/Internal.hs
oberblastmeister Apr 11, 2022
bf9a27f
Update Data/HashSet/Internal.hs
oberblastmeister Apr 11, 2022
1c20739
naming
oberblastmeister Apr 11, 2022
92e4b2a
Exts.inline
oberblastmeister Apr 11, 2022
5a439cc
add haddocks for searchSwap
oberblastmeister Apr 12, 2022
1c118c4
cleanup
oberblastmeister Apr 12, 2022
1256cf3
Update Data/HashMap/Internal/Array.hs
oberblastmeister Apr 12, 2022
b0210c8
Update Data/HashMap/Internal/Array.hs
oberblastmeister Apr 12, 2022
69f8f28
refactor
oberblastmeister Apr 13, 2022
06cc511
formatting
oberblastmeister Apr 13, 2022
d9a50d7
breakup lines
oberblastmeister Apr 13, 2022
d24cc1f
use Exts.inline
oberblastmeister Apr 14, 2022
64f3f2f
Merge branch 'master' into fast-intersection
oberblastmeister Apr 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 132 additions & 31 deletions Data/HashMap/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,16 @@ import Control.Applicative (Const (..))
import Control.DeepSeq (NFData (..), NFData1 (..), NFData2 (..))
import Control.Monad.ST (ST, runST)
import Data.Bifoldable (Bifoldable (..))
import Data.Bits (complement, popCount, unsafeShiftL,
unsafeShiftR, (.&.), (.|.), countTrailingZeros)
import Data.Bits (complement, countTrailingZeros, popCount,
unsafeShiftL, unsafeShiftR, (.&.), (.|.))
import Data.Coerce (coerce)
import Data.Data (Constr, Data (..), DataType)
import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..),
Read1 (..), Show1 (..), Show2 (..))
import Data.Functor.Identity (Identity (..))
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
import Data.Hashable (Hashable)
import Data.Hashable.Lifted (Hashable1, Hashable2)
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, the changed sorting of imports is probably due to haskell/stylish-haskell#385, which was recently released.

import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid)
import GHC.Exts (Int (..), Int#, TYPE, (==#))
import GHC.Stack (HasCallStack)
Expand All @@ -163,9 +163,9 @@ import Text.Read hiding (step)
import qualified Data.Data as Data
import qualified Data.Foldable as Foldable
import qualified Data.Functor.Classes as FC
import qualified Data.HashMap.Internal.Array as A
import qualified Data.Hashable as H
import qualified Data.Hashable.Lifted as H
import qualified Data.HashMap.Internal.Array as A
import qualified Data.List as List
import qualified GHC.Exts as Exts
import qualified Language.Haskell.TH.Syntax as TH
Expand All @@ -178,7 +178,7 @@ hash :: H.Hashable a => a -> Hash
hash = fromIntegral . H.hash

data Leaf k v = L !k v
deriving (Eq)
deriving (Show, Eq)

instance (NFData k, NFData v) => NFData (Leaf k v) where
rnf (L k v) = rnf k `seq` rnf v
Expand Down Expand Up @@ -210,6 +210,7 @@ data HashMap k v
| Leaf !Hash !(Leaf k v)
| Full !(A.Array (HashMap k v))
| Collision !Hash !(A.Array (Leaf k v))
deriving (Show)

type role HashMap nominal representational

Expand Down Expand Up @@ -337,9 +338,9 @@ instance (Eq k, Hashable k, Read k, Read e) => Read (HashMap k e) where

readListPrec = readListPrecDefault

instance (Show k, Show v) => Show (HashMap k v) where
showsPrec d m = showParen (d > 10) $
showString "fromList " . shows (toList m)
-- instance (Show k, Show v) => Show (HashMap k v) where
-- showsPrec d m = showParen (d > 10) $
-- showString "fromList " . shows (toList m)

instance Traversable (HashMap k) where
traverse f = traverseWithKey (const f)
Expand Down Expand Up @@ -1602,10 +1603,6 @@ unionWithKey f = go 0
ary' = update32With' ary2 i $ \st2 -> go (s+bitsPerSubkey) t1 st2
in Full ary'

leafHashCode (Leaf h _) = h
leafHashCode (Collision h _) = h
leafHashCode _ = error "leafHashCode"

goDifferentHash s h1 h2 t1 t2
| m1 == m2 = BitmapIndexed m1 (A.singleton $! goDifferentHash (s+bitsPerSubkey) h1 h2 t1 t2)
| m1 < m2 = BitmapIndexed (m1 .|. m2) (A.pair t1 t2)
Expand Down Expand Up @@ -1639,7 +1636,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do
A.write mary i =<< A.indexM ary2 i2
go (i+1) i1 (i2+1) b'
where
m = 1 `unsafeShiftL` (countTrailingZeros b)
m = 1 `unsafeShiftL` countTrailingZeros b
testBit x = x .&. m /= 0
b' = b .&. complement m
go 0 0 0 bCombined
Expand Down Expand Up @@ -1771,37 +1768,135 @@ differenceWith f a b = foldlWithKey' go empty a
-- | /O(n*log m)/ Intersection of two maps. Return elements of the first
-- map for keys existing in the second.
intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v
intersection a b = foldlWithKey' go empty a
where
go m k v = case lookup k b of
Just _ -> unsafeInsert k v m
_ -> m
intersection = intersectionWith const
{-# INLINABLE intersection #-}

-- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
-- maps.
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1
-> HashMap k v2 -> HashMap k v3
intersectionWith f a b = foldlWithKey' go empty a
where
go m k v = case lookup k b of
Just w -> unsafeInsert k (f v w) m
_ -> m
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWith f = intersectionWithKey $ const f
{-# INLINABLE intersectionWith #-}

-- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
-- maps.
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3)
-> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f a b = foldlWithKey' go empty a
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f = go 0
where
go m k v = case lookup k b of
Just w -> unsafeInsert k (f k v w) m
_ -> m
-- empty vs. anything
go !_ _ Empty = Empty
go _ Empty _ = Empty
-- leaf vs. anything
go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> Leaf h1 $ L k1 $ f k1 v1 v) h1 k1 s t2
go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> Leaf h2 $ L k2 $ f k2 v v2) h2 k2 s t1
-- collision vs. collision
go _ (Collision h1 ls1) (Collision h2 ls2)
| h1 == h2 = if A.length ls == 0 then Empty else Collision h1 ls
| otherwise = Empty
where
ls = intersectionUnorderedArrayWithKey (\k v1 v2 -> (# f k v1 v2 #)) ls1 ls2
-- branch vs. branch
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2
go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2
go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArray s fullNodeMask b2 ary1 ary2
go s (Full ary1) (Full ary2) = intersectionArray s fullNodeMask fullNodeMask ary1 ary2
-- collision vs. branch
go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2)
| b1 .&. m2 == 0 = Empty
| otherwise = go (s + bitsPerSubkey) (A.index ary1 i) t2
where
m2 = mask h2 s
i = sparseIndex b1 m2
go s t1@(Collision h1 _ls1) (BitmapIndexed b2 ary2)
| b2 .&. m1 == 0 = Empty
| otherwise = go (s + bitsPerSubkey) t1 (A.index ary2 i)
where
m1 = mask h1 s
i = sparseIndex b2 m1
go s (Full ary1) t2@(Collision h2 _ls2) = go (s + bitsPerSubkey) (A.index ary1 i) t2
where
i = index h2 s
go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i)
where
i = index h1 s

intersectionArray s b1 b2 ary1 ary2 = normalize b ary
where
(b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2

normalize b ary
| A.length ary == 0 = Empty
| otherwise = bitmapIndexedOrFull b ary
{-# INLINABLE intersectionWithKey #-}

intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v))
intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True
{-# INLINE intersectionArrayBy #-}

intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3)
intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do
let bCombined = b1 .|. b2
mary <- A.new_ $ popCount $ b1 .&. b2
-- iterate over nonzero bits of b1 .&. b2
let go !i !i1 !i2 !b !bFinal
| b == 0 = pure (i, bFinal)
| testBit $ b1 .&. b2 = do
x1 <- A.indexM ary1 i1
x2 <- A.indexM ary2 i2
let !x = f x1 x2
if p x
then do
A.write mary i $! f x1 x2
go (i + 1) (i1 + 1) (i2 + 1) b' bFinal
else go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m)
| testBit b1 = go i (i1 + 1) i2 b' bFinal
| otherwise = go i i1 (i2 + 1) b' bFinal
where
m = 1 `unsafeShiftL` countTrailingZeros b
testBit x = x .&. m /= 0
b' = b .&. complement m
(maryLen, bFinal) <- go 0 0 0 bCombined (b1 .&. b2)
A.shrink mary maryLen
ary <- A.unsafeFreeze mary
pure (bFinal, ary)
{-# INLINE intersectionArrayByFilter #-}

intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3)
intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do
mary2 <- A.thaw ary2 0 $ A.length ary2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we actually need to allocate two arrays for this. The alternative would be to perform the search-and-swap operations on the output array itself.

It might be a bit tricky though – maybe leave it for a follow-up PR, so this one doesn't get too huge.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the issue with this is that the type could change. For example if we have two arrays with the numbers as keys, and the arrays are both different types
1 2 3 4
3 4 2 1
Let's thaw the first array, and mutate it to
(f 3 3) 2 1 4
f 3 3 could change the type to be something difference than the 2 1 4.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, good point. Unsafe coercions might work for this, but I'd prefer not trying this in this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only an issue for intersectionWithKey and such; intersection itself has no type issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another thing about intersection that's special: we can reuse the leaves.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would be better if intersection had custom code for handling collisions. Maybe this can be achieved by changing intersectionWithKey# to something similar to filterMapAux.

I'd slightly prefer if we'd leave this for a follow-up PR though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have recorded these ideas in #415.

mary <- A.new_ $ A.length ary1 + A.length ary2
let go i j
| i >= A.length ary1 || j >= A.lengthM mary2 = pure j
| otherwise = do
L k1 v1 <- A.indexM ary1 i
searchSwap k1 j mary2 >>= \case
Just (L _k2 v2) -> do
let !(# v3 #) = f k1 v1 v2
A.write mary j $ L k1 v3
go (i + 1) (j + 1)
Nothing -> do
go (i + 1) j
maryLen <- go 0 0
A.shrink mary maryLen
pure mary
{-# INLINABLE intersectionUnorderedArrayWithKey #-}

searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v))
searchSwap toFind start = go start toFind start
where
go i0 k i mary
| i >= A.lengthM mary = pure Nothing
| otherwise = do
l@(L k' _v) <- A.read mary i
if k == k'
then do
A.write mary i =<< A.read mary i0
pure $ Just l
else go i0 k (i + 1) mary
{-# INLINE searchSwap #-}


------------------------------------------------------------------------
-- * Folds

Expand Down Expand Up @@ -2282,6 +2377,12 @@ ptrEq :: a -> a -> Bool
ptrEq x y = Exts.isTrue# (Exts.reallyUnsafePtrEquality# x y ==# 1#)
{-# INLINE ptrEq #-}

leafHashCode :: HashMap k v -> Hash
leafHashCode (Leaf h _) = h
leafHashCode (Collision h _) = h
leafHashCode _ = error "leafHashCode"
{-# INLINE leafHashCode #-}

------------------------------------------------------------------------
-- IsList instance
instance (Eq k, Hashable k) => Exts.IsList (HashMap k v) where
Expand Down
9 changes: 8 additions & 1 deletion Data/HashMap/Internal/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ module Data.HashMap.Internal.Array
, toList
, fromList
, fromList'
, shrink
) where

import Control.Applicative (liftA2)
Expand All @@ -90,7 +91,7 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#,
sizeofSmallMutableArray#, tagToEnum#,
thawSmallArray#, unsafeCoerce#,
unsafeFreezeSmallArray#, unsafeThawSmallArray#,
writeSmallArray#)
writeSmallArray#, shrinkSmallMutableArray#)
import GHC.ST (ST (..))
import Prelude hiding (all, filter, foldMap, foldl, foldr, length,
map, read, traverse)
Expand Down Expand Up @@ -204,6 +205,12 @@ new _n@(I# n#) b =
new_ :: Int -> ST s (MArray s a)
new_ n = new n undefinedElem

shrink :: MArray s a -> Int -> ST s ()
shrink mary (I# n#) =
ST $ \s -> case shrinkSmallMutableArray# (unMArray mary) n# s of
s' -> (# s', () #)
{-# INLINE shrink #-}

singleton :: a -> Array a
singleton x = runST (singletonM x)
{-# INLINE singleton #-}
Expand Down
2 changes: 1 addition & 1 deletion Data/HashSet/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ difference (HashSet a) (HashSet b) = HashSet (H.difference a b)
--
-- >>> HashSet.intersection (HashSet.fromList [1,2,3]) (HashSet.fromList [2,3,4])
-- fromList [2,3]
intersection :: (Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a
intersection :: (Show a, Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a
intersection (HashSet a) (HashSet b) = HashSet (H.intersection a b)
{-# INLINABLE intersection #-}

Expand Down
25 changes: 21 additions & 4 deletions tests/Properties/HashMapLazy.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-} -- because of Arbitrary (HashMap k v)
{-# LANGUAGE BangPatterns #-}

-- | Tests for the 'Data.HashMap.Lazy' module. We test functions by
-- comparing them to @Map@ from @containers@.
Expand All @@ -17,6 +18,7 @@ import Control.Applicative (Const (..))
import Control.Monad (guard)
import Data.Bifoldable
import Data.Function (on)
import Debug.Trace (traceId)
import Data.Functor.Identity (Identity (..))
import Data.Hashable (Hashable (hashWithSalt))
import Data.Ord (comparing)
Expand All @@ -26,6 +28,7 @@ import Test.QuickCheck.Function (Fun, apply)
import Test.QuickCheck.Poly (A, B)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.QuickCheck (testProperty)
import Test.Tasty.HUnit

import qualified Data.Foldable as Foldable
import qualified Data.List as List
Expand All @@ -42,7 +45,7 @@ import qualified Data.Map.Lazy as M

-- Key type that generates more hash collisions.
newtype Key = K { unK :: Int }
deriving (Arbitrary, Eq, Ord, Read, Show)
deriving (Arbitrary, Eq, Ord, Read, Show, Num)

instance Hashable Key where
hashWithSalt salt k = hashWithSalt salt (unK k) `mod` 20
Expand Down Expand Up @@ -249,7 +252,15 @@ pSubmapDifference m1 m2 = HM.isSubmapOf (HM.difference m1 m2) m1

pNotSubmapDifference :: HashMap Key Int -> HashMap Key Int -> Property
pNotSubmapDifference m1 m2 =
not (HM.null (HM.intersection m1 m2)) ==>
not (HM.null (HM.intersection m1 m2)) ==> do

let
res = HM.intersection m1 m2
res' = M.intersection (M.fromList $ HM.toList m1) (M.fromList $ HM.toList m2)
-- !_ = traceId $ "res: " ++ show res
-- !_ = traceId $ "res': " ++ show res'
-- !_ = traceId $ "m1: " ++ show m1
-- !_ = traceId $ "m2: " ++ show m2
not (HM.isSubmapOf m1 (HM.difference m1 m2))

pSubmapDelete :: HashMap Key Int -> Property
Expand Down Expand Up @@ -318,8 +329,13 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_`
f x y = if x == 0 then Nothing else Just (x - y)

pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool
pIntersection xs ys = M.intersection (M.fromList xs) `eq_`
HM.intersection (HM.fromList xs) $ ys
pIntersection xs ys =
M.intersection (M.fromList xs)
`eq_` HM.intersection (HM.fromList xs)
$ ys

intersectionBad :: Assertion
intersectionBad = pIntersection [(-20, 0), (0, 0)] [(0, 0), (20, 0)] @? "should be true"

pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool
pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_`
Expand Down Expand Up @@ -531,6 +547,7 @@ tests =
[ testProperty "difference" pDifference
, testProperty "differenceWith" pDifferenceWith
, testProperty "intersection" pIntersection
, testCase "intersectionBad" intersectionBad
, testProperty "intersectionWith" pIntersectionWith
, testProperty "intersectionWithKey" pIntersectionWithKey
]
Expand Down