- 
                Notifications
    You must be signed in to change notification settings 
- Fork 103
Make intersections much faster #406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 26 commits
21f238b
              16f1f7f
              bcc13fc
              d5262bf
              a16456b
              f72011c
              678a38c
              ec24215
              767ae6e
              72510b4
              fd43ba7
              3612645
              b484042
              9e48bc0
              48119cb
              d9d295d
              88a9c2c
              b3cdbd8
              bf9a27f
              1c20739
              92e4b2a
              5a439cc
              1c118c4
              1256cf3
              b0210c8
              69f8f28
              06cc511
              d9a50d7
              d24cc1f
              64f3f2f
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -78,6 +78,7 @@ module Data.HashMap.Internal | |
| , intersection | ||
| , intersectionWith | ||
| , intersectionWithKey | ||
| , intersectionWithKey# | ||
|  | ||
| -- * Folds | ||
| , foldr' | ||
|  | @@ -143,16 +144,16 @@ import Control.Applicative (Const (..)) | |
| import Control.DeepSeq (NFData (..), NFData1 (..), NFData2 (..)) | ||
| import Control.Monad.ST (ST, runST) | ||
| import Data.Bifoldable (Bifoldable (..)) | ||
| import Data.Bits (complement, popCount, unsafeShiftL, | ||
| unsafeShiftR, (.&.), (.|.), countTrailingZeros) | ||
| import Data.Bits (complement, countTrailingZeros, popCount, | ||
| unsafeShiftL, unsafeShiftR, (.&.), (.|.)) | ||
| import Data.Coerce (coerce) | ||
| import Data.Data (Constr, Data (..), DataType) | ||
| import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..), | ||
| Read1 (..), Show1 (..), Show2 (..)) | ||
| import Data.Functor.Identity (Identity (..)) | ||
| import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) | ||
| import Data.Hashable (Hashable) | ||
| import Data.Hashable.Lifted (Hashable1, Hashable2) | ||
| import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) | ||
| import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) | ||
| import GHC.Exts (Int (..), Int#, TYPE, (==#)) | ||
| import GHC.Stack (HasCallStack) | ||
|  | @@ -163,9 +164,9 @@ import Text.Read hiding (step) | |
| import qualified Data.Data as Data | ||
| import qualified Data.Foldable as Foldable | ||
| import qualified Data.Functor.Classes as FC | ||
| import qualified Data.HashMap.Internal.Array as A | ||
| import qualified Data.Hashable as H | ||
| import qualified Data.Hashable.Lifted as H | ||
| import qualified Data.HashMap.Internal.Array as A | ||
| import qualified Data.List as List | ||
|         
                  oberblastmeister marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| import qualified GHC.Exts as Exts | ||
| import qualified Language.Haskell.TH.Syntax as TH | ||
|  | @@ -819,17 +820,9 @@ insertNewKey !h0 !k0 x0 !m0 = go h0 k0 x0 0 m0 | |
| in Full (update32 ary i st') | ||
| where i = index h s | ||
| go h k x s t@(Collision hy v) | ||
| | h == hy = Collision h (snocNewLeaf (L k x) v) | ||
| | h == hy = Collision h (A.snoc v (L k x)) | ||
| | otherwise = | ||
| go h k x s $ BitmapIndexed (mask hy s) (A.singleton t) | ||
| where | ||
| snocNewLeaf :: Leaf k v -> A.Array (Leaf k v) -> A.Array (Leaf k v) | ||
| snocNewLeaf leaf ary = A.run $ do | ||
| let n = A.length ary | ||
| mary <- A.new_ (n + 1) | ||
| A.copy ary 0 mary 0 n | ||
| A.write mary n leaf | ||
| return mary | ||
| {-# NOINLINE insertNewKey #-} | ||
|  | ||
|  | ||
|  | @@ -1008,12 +1001,8 @@ insertModifyingArr :: Eq k => v -> (v -> (# v #)) -> k -> A.Array (Leaf k v) | |
| insertModifyingArr x f k0 ary0 = go k0 ary0 0 (A.length ary0) | ||
| where | ||
| go !k !ary !i !n | ||
| | i >= n = A.run $ do | ||
| -- Not found, append to the end. | ||
| mary <- A.new_ (n + 1) | ||
| A.copy ary 0 mary 0 n | ||
| A.write mary n (L k x) | ||
| return mary | ||
| -- Not found, append to the end. | ||
| | i >= n = A.snoc ary $ L k x | ||
| | otherwise = case A.index ary i of | ||
| (L kx y) | k == kx -> case f y of | ||
| (# y' #) -> if ptrEq y y' | ||
|  | @@ -1639,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do | |
| A.write mary i =<< A.indexM ary2 i2 | ||
| go (i+1) i1 (i2+1) b' | ||
| where | ||
| m = 1 `unsafeShiftL` (countTrailingZeros b) | ||
| m = 1 `unsafeShiftL` countTrailingZeros b | ||
| testBit x = x .&. m /= 0 | ||
| b' = b .&. complement m | ||
| go 0 0 0 bCombined | ||
|  | @@ -1771,37 +1760,149 @@ differenceWith f a b = foldlWithKey' go empty a | |
| -- | /O(n*log m)/ Intersection of two maps. Return elements of the first | ||
| -- map for keys existing in the second. | ||
| intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v | ||
| intersection a b = foldlWithKey' go empty a | ||
| where | ||
| go m k v = case lookup k b of | ||
| Just _ -> unsafeInsert k v m | ||
| _ -> m | ||
| intersection = Exts.inline intersectionWith const | ||
| {-# INLINABLE intersection #-} | ||
|         
                  oberblastmeister marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
|  | ||
| -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps | ||
| -- the provided function is used to combine the values from the two | ||
| -- maps. | ||
| intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 | ||
| -> HashMap k v2 -> HashMap k v3 | ||
| intersectionWith f a b = foldlWithKey' go empty a | ||
| where | ||
| go m k v = case lookup k b of | ||
| Just w -> unsafeInsert k (f v w) m | ||
| _ -> m | ||
| intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 | ||
| intersectionWith f = Exts.inline intersectionWithKey $ const f | ||
| {-# INLINABLE intersectionWith #-} | ||
|  | ||
| -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps | ||
| -- the provided function is used to combine the values from the two | ||
| -- maps. | ||
| intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) | ||
| -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 | ||
| intersectionWithKey f a b = foldlWithKey' go empty a | ||
| where | ||
| go m k v = case lookup k b of | ||
| Just w -> unsafeInsert k (f k v w) m | ||
| _ -> m | ||
| intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 | ||
| intersectionWithKey f = intersectionWithKey# $ \k v1 v2 -> (# f k v1 v2 #) | ||
| {-# INLINABLE intersectionWithKey #-} | ||
|  | ||
-- | Worker shared by the @intersection*@ functions.
--
-- Walks both tries in lock step, starting at shift @0@, and keeps only the
-- keys found on both sides. The combining function hands its result back in
-- an unboxed 1-tuple, so the caller decides whether the combined value is
-- forced.
--
-- The key stored in the result is taken from whichever side supplied the
-- 'Leaf' being looked up (left map in the leaf-vs-anything case, right map in
-- the anything-vs-leaf case).
intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey# f = go 0
  where
    -- Either side is empty: the intersection is empty.
    go !_ _ Empty = Empty
    go _ Empty _ = Empty

    -- Left side is a single leaf: look its key up in the right subtree.
    go s (Leaf h1 (L k1 v1)) t2 =
      lookupCont
        (\_ -> Empty)
        (\v2 _ -> case f k1 v1 v2 of (# v3 #) -> Leaf h1 (L k1 v3))
        h1
        k1
        s
        t2

    -- Right side is a single leaf: look its key up in the left subtree.
    go s t1 (Leaf h2 (L k2 v2)) =
      lookupCont
        (\_ -> Empty)
        (\v1 _ -> case f k2 v1 v2 of (# v3 #) -> Leaf h2 (L k2 v3))
        h2
        k2
        s
        t1

    -- Two collision buckets: handled by the dedicated array routine.
    go _ (Collision h1 ls1) (Collision h2 ls2) =
      intersectionCollisions f h1 h2 ls1 ls2

    -- Two branch nodes: intersect slot by slot, one level further down.
    -- A 'Full' node is treated as a bitmap node whose bitmap is 'fullNodeMask'.
    go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
      intersectionArrayBy (go next) b1 b2 ary1 ary2
      where
        next = s + bitsPerSubkey
    go s (BitmapIndexed b1 ary1) (Full ary2) =
      intersectionArrayBy (go next) b1 fullNodeMask ary1 ary2
      where
        next = s + bitsPerSubkey
    go s (Full ary1) (BitmapIndexed b2 ary2) =
      intersectionArrayBy (go next) fullNodeMask b2 ary1 ary2
      where
        next = s + bitsPerSubkey
    go s (Full ary1) (Full ary2) =
      intersectionArrayBy (go next) fullNodeMask fullNodeMask ary1 ary2
      where
        next = s + bitsPerSubkey

    -- Branch vs. collision bucket: only the one child selected by the
    -- bucket's hash at this shift can contribute, so descend into it (or
    -- stop if that slot is not populated).
    go s (BitmapIndexed b1 ary1) t2@(Collision h2 _)
      | b1 .&. slot /= 0 = go (s + bitsPerSubkey) (A.index ary1 (sparseIndex b1 slot)) t2
      | otherwise = Empty
      where
        slot = mask h2 s
    go s t1@(Collision h1 _) (BitmapIndexed b2 ary2)
      | b2 .&. slot /= 0 = go (s + bitsPerSubkey) t1 (A.index ary2 (sparseIndex b2 slot))
      | otherwise = Empty
      where
        slot = mask h1 s

    -- Full node vs. collision bucket: descend into the child picked by
    -- @index hash s@; no bitmap test is made for a 'Full' node.
    go s (Full ary1) t2@(Collision h2 _) =
      go (s + bitsPerSubkey) (A.index ary1 (index h2 s)) t2
    go s t1@(Collision h1 _) (Full ary2) =
      go (s + bitsPerSubkey) t1 (A.index ary2 (index h1 s))
{-# INLINE intersectionWithKey# #-}
|  | ||
-- | Intersect the child arrays of two branch nodes.
--
-- @b1@ and @b2@ are the bitmaps that go with @ary1@ and @ary2@; @f@
-- intersects one pair of children that occupy the same slot. Slots populated
-- on only one side are skipped. Slots whose intersection is 'Empty' are
-- removed from the result bitmap.
--
-- The result is 'Empty' when nothing survives, the surviving child itself
-- when it is the only one and is a 'Leaf' or 'Collision', and a branch node
-- otherwise.
intersectionArrayBy ::
  ( HashMap k v1 ->
    HashMap k v2 ->
    HashMap k v3
  ) ->
  Bitmap ->
  Bitmap ->
  A.Array (HashMap k v1) ->
  A.Array (HashMap k v2) ->
  HashMap k v3
intersectionArrayBy f !b1 !b2 !ary1 !ary2
  | bIntersect == 0 = Empty
  | otherwise = runST $ do
    -- Upper bound on the number of surviving children; shrunk at the end.
    mary <- A.new_ $ popCount bIntersect
    -- Iterate over the nonzero bits of b1 .|. b2, lowest bit first.
    -- i  = next free slot in mary, i1 / i2 = cursors into ary1 / ary2,
    -- b  = bits still to visit, bFinal = bitmap of the result so far.
    let go !i !i1 !i2 !b !bFinal
          | b == 0 = pure (i, bFinal)
          | testBit bIntersect = do
            x1 <- A.indexM ary1 i1
            x2 <- A.indexM ary2 i2
            -- Scrutinising the result forces it to WHNF, so the bound
            -- value can be stored as is; the sub-intersection is written
            -- only once instead of being recomputed for the write.
            case f x1 x2 of
              Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m)
              t -> do
                A.write mary i t
                go (i + 1) (i1 + 1) (i2 + 1) b' bFinal
          | testBit b1 = go i (i1 + 1) i2 b' bFinal
          | otherwise = go i i1 (i2 + 1) b' bFinal
          where
            m = 1 `unsafeShiftL` countTrailingZeros b
            testBit x = x .&. m /= 0
            b' = b .&. complement m
    (len, bFinal) <- go 0 0 0 bCombined bIntersect
    case len of
      0 -> pure Empty
      1 -> do
        t <- A.read mary 0
        case t of
          -- Leaves and collision buckets carry their full hash, so they
          -- can be hoisted into the parent's position.
          Leaf{} -> pure t
          Collision{} -> pure t
          -- A branch child was built for the next shift level. Returning
          -- it directly would place it one level too high, so keep it
          -- wrapped in a one-slot bitmap node.
          _ -> BitmapIndexed bFinal <$> (A.unsafeFreeze =<< A.shrink mary 1)
      _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len)
  where
    bCombined = b1 .|. b2
    bIntersect = b1 .&. b2
{-# INLINE intersectionArrayBy #-}
|  | ||
-- | Intersect two collision buckets.
--
-- Buckets with different hashes share no keys, so the result is 'Empty'.
-- Otherwise every leaf of the first bucket is searched for in a mutable copy
-- of the second one (via 'searchSwap'), and the matches are combined with the
-- supplied function. Result keys are taken from the first bucket.
--
-- The result is 'Empty', a single 'Leaf', or a 'Collision', depending on how
-- many keys matched.
intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3
intersectionCollisions f h1 h2 ary1 ary2
  | h1 == h2 = runST $ do
    -- Private, mutable copy of the second bucket for 'searchSwap' to permute.
    scratch <- A.thaw ary2 0 len2
    -- At most min(len1, len2) keys can match; shrunk to fit afterwards.
    out <- A.new_ (min len1 len2)
    -- i walks the first bucket; matched counts the hits so far and is both
    -- the next write position in out and the search start in scratch.
    let loop i matched
          | i >= len1 || matched >= A.lengthM scratch = pure matched
          | otherwise = do
            L k1 v1 <- A.indexM ary1 i
            hit <- searchSwap k1 matched scratch
            case hit of
              Nothing -> loop (i + 1) matched
              Just (L _ v2) ->
                case f k1 v1 v2 of
                  (# v3 #) -> do
                    A.write out matched (L k1 v3)
                    loop (i + 1) (matched + 1)
    total <- loop 0 0
    case total of
      0 -> pure Empty
      1 -> Leaf h1 <$> A.read out 0
      _ -> do
        trimmed <- A.shrink out total
        Collision h1 <$> A.unsafeFreeze trimmed
  | otherwise = Empty
  where
    len1 = A.length ary1
    len2 = A.length ary2
{-# INLINE intersectionCollisions #-}
|  | ||
-- | @searchSwap toFind start mary@ scans @mary@ from index @start@ upwards
-- for a leaf whose key equals @toFind@.
--
-- On a hit at index @i@ the element currently stored at @start@ is copied
-- into slot @i@ and the leaf that was found is returned. On a miss the array
-- is left untouched and the result is 'Nothing'.
--
-- For example, searching for @3@ from index @0@ in
--
-- @
-- 1 2 3 4
-- @
--
-- leaves the array as
--
-- @
-- 1 2 1 4
-- @
--
-- Slot @start@ is not overwritten; it is treated as consumed. The caller is
-- expected to begin its next search at @start + 1@, so the stale copy left
-- there is never looked at again.
searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v))
searchSwap toFind start = scan start
  where
    scan i mary
      | i >= A.lengthM mary = pure Nothing
      | otherwise = do
          leaf@(L key _) <- A.read mary i
          if toFind == key
            then do
              -- Preserve the not-yet-matched element from the front slot
              -- by moving it into the slot that has just been consumed.
              displaced <- A.read mary start
              A.write mary i displaced
              pure (Just leaf)
            else scan (i + 1) mary
{-# INLINE searchSwap #-}
|  | ||
|  | ||
| ------------------------------------------------------------------------ | ||
| -- * Folds | ||
|  | ||
|  | @@ -2164,12 +2265,8 @@ updateOrSnocWithKey :: Eq k => (k -> v -> v -> (# v #)) -> k -> v -> A.Array (Le | |
| updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0) | ||
| where | ||
| go !k v !ary !i !n | ||
| | i >= n = A.run $ do | ||
| -- Not found, append to the end. | ||
| mary <- A.new_ (n + 1) | ||
| A.copy ary 0 mary 0 n | ||
| A.write mary n (L k v) | ||
| return mary | ||
| -- Not found, append to the end. | ||
| | i >= n = A.snoc ary $ L k v | ||
| | L kx y <- A.index ary i | ||
| , k == kx | ||
| , (# v2 #) <- f k v y | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, the changed sorting of imports is probably due to haskell/stylish-haskell#385, which was recently released.