Skip to content

Commit 3ff1715

Browse files
authored
Add RULES for nub functions (#517)
* Add rewrite rules to allow the nub functions in `ListUtils` to participate in fold/build fusion. * For the sake of simplicity, define `nubOrd` and `nubInt` in terms of `nubOrdOn` and `nubIntOn`.
1 parent b7f1c86 commit 3ff1715

File tree

2 files changed

+193
-44
lines changed

2 files changed

+193
-44
lines changed

Data/Containers/ListUtils.hs

Lines changed: 148 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE CPP #-}
2-
#if !defined(TESTING) && __GLASGOW_HASKELL__ >= 703
3-
{-# LANGUAGE Safe #-}
2+
{-# LANGUAGE BangPatterns #-}
3+
#if __GLASGOW_HASKELL__ >= 703
4+
{-# LANGUAGE Trustworthy #-}
45
#endif
56

67
-----------------------------------------------------------------------------
@@ -21,45 +22,160 @@ module Data.Containers.ListUtils (
2122
nubIntOn
2223
) where
2324

25+
import Data.Set (Set)
2426
import qualified Data.Set as Set
2527
import qualified Data.IntSet as IntSet
28+
import Data.IntSet (IntSet)
29+
#ifdef __GLASGOW_HASKELL__
30+
import GHC.Exts ( build )
31+
#endif
2632

27-
-- | /O(n log n)/. The 'nubOrd' function removes duplicate elements from a list.
28-
-- In particular, it keeps only the first occurrence of each element. By using a 'Set' internally
29-
-- it has better asymptotics than the standard 'nub' function.
30-
nubOrd :: (Ord a) => [a] -> [a]
31-
nubOrd = go Set.empty
32-
where
33-
go _ [] = []
34-
go s (x:xs) = if x `Set.member` s then go s xs
35-
else x : go (Set.insert x s) xs
33+
-- *** Ord-based nubbing ***
3634

37-
-- | The `nubOrdOn` function behaves just like `nubOrd` except it performs comparisons not on the
38-
-- original datatype, but a user-specified projection from that datatype.
39-
nubOrdOn :: (Ord b) => (a -> b) -> [a] -> [a]
40-
nubOrdOn f = go Set.empty
35+
36+
-- | \( O(n \log n) \). The @nubOrd@ function removes duplicate elements from a list.
37+
-- In particular, it keeps only the first occurrence of each element. By using a
38+
-- 'Set' internally it has better asymptotics than the standard 'Data.List.nub'
39+
-- function.
40+
--
41+
-- ==== Strictness
42+
--
43+
-- @nubOrd@ is strict in the elements of the list.
44+
--
45+
-- ==== Efficiency note
46+
--
47+
-- When applicable, it is almost always better to use 'nubInt' or 'nubIntOn' instead
48+
-- of this function. For example, the best way to nub a list of characters is
49+
--
50+
-- @ nubIntOn fromEnum xs @
51+
nubOrd :: Ord a => [a] -> [a]
52+
nubOrd = nubOrdOn id
53+
{-# INLINE nubOrd #-}
54+
55+
-- | The @nubOrdOn@ function behaves just like 'nubOrd' except it performs
56+
-- comparisons not on the original datatype, but a user-specified projection
57+
-- from that datatype.
58+
--
59+
-- ==== Strictness
60+
--
61+
-- @nubOrdOn@ is strict in the values of the function applied to the
62+
-- elements of the list.
63+
nubOrdOn :: Ord b => (a -> b) -> [a] -> [a]
64+
-- GHC only inlines an INLINE function once it is applied to every argument
65+
-- left of the '='; the explicit lambda lets this inline when applied to f alone.
66+
nubOrdOn f = \xs -> nubOrdOnExcluding f Set.empty xs
67+
{-# INLINE nubOrdOn #-}
68+
69+
-- Splitting nubOrdOn like this means that we don't have to worry about
70+
-- matching specifically on Set.empty in the rewrite-back rule.
71+
nubOrdOnExcluding :: Ord b => (a -> b) -> Set b -> [a] -> [a]
72+
nubOrdOnExcluding f = go
4173
where
4274
go _ [] = []
43-
go s (x:xs) = let fx = f x
44-
in if fx `Set.member` s then go s xs
45-
else x : go (Set.insert fx s) xs
75+
go s (x:xs)
76+
| fx `Set.member` s = go s xs
77+
| otherwise = x : go (Set.insert fx s) xs
78+
where !fx = f x
79+
80+
#ifdef __GLASGOW_HASKELL__
81+
-- We want this inlinable to specialize to the necessary Ord instance.
82+
{-# INLINABLE [1] nubOrdOnExcluding #-}
83+
84+
{-# RULES
85+
-- Rewrite to a fusible form.
86+
"nubOrdOn" [~1] forall f as s. nubOrdOnExcluding f s as =
87+
build (\c n -> foldr (nubOrdOnFB f c) (constNubOn n) as s)
88+
89+
-- Rewrite back to a plain form
90+
"nubOrdOnList" [1] forall f as s.
91+
foldr (nubOrdOnFB f (:)) (constNubOn []) as s =
92+
nubOrdOnExcluding f s as
93+
#-}
94+
95+
nubOrdOnFB :: Ord b
96+
=> (a -> b)
97+
-> (a -> r -> r)
98+
-> a
99+
-> (Set b -> r)
100+
-> Set b
101+
-> r
102+
nubOrdOnFB f c x r s
103+
| fx `Set.member` s = r s
104+
| otherwise = x `c` r (Set.insert fx s)
105+
where !fx = f x
106+
{-# INLINABLE [0] nubOrdOnFB #-}
46107

47-
-- | /O(n min(n,W))/. The 'nubInt' function removes duplicate elements from a list.
48-
-- In particular, it keeps only the first occurrence of each element. By using an 'IntSet' internally
49-
-- it has better asymptotics than the standard 'nub' function.
108+
constNubOn :: a -> b -> a
109+
constNubOn x _ = x
110+
{-# INLINE [0] constNubOn #-}
111+
#endif
112+
113+
114+
-- *** Int-based nubbing ***
115+
116+
117+
-- | \( O(n \min(n,W)) \). The @nubInt@ function removes duplicate 'Int'
118+
-- values from a list. In particular, it keeps only the first occurrence
119+
-- of each element. By using an 'IntSet' internally, it attains better
120+
-- asymptotics than the standard 'Data.List.nub' function.
121+
--
122+
-- See also 'nubIntOn', a more widely applicable generalization.
123+
--
124+
-- ==== Strictness
125+
--
126+
-- @nubInt@ is strict in the elements of the list.
50127
nubInt :: [Int] -> [Int]
51-
nubInt = go IntSet.empty
52-
where
53-
go _ [] = []
54-
go s (x:xs) = if x `IntSet.member` s then go s xs
55-
else x : go (IntSet.insert x s) xs
128+
nubInt = nubIntOn id
129+
{-# INLINE nubInt #-}
56130

57-
-- | The `nubIntOn` function behaves just like 'nubInt' except it performs comparisons not on the
58-
-- original datatype, but a user-specified projection from that datatype to 'Int'.
131+
-- | The @nubIntOn@ function behaves just like 'nubInt' except it performs
132+
-- comparisons not on the original datatype, but a user-specified projection
133+
-- from that datatype.
134+
--
135+
-- ==== Strictness
136+
--
137+
-- @nubIntOn@ is strict in the values of the function applied to the
138+
-- elements of the list.
59139
nubIntOn :: (a -> Int) -> [a] -> [a]
60-
nubIntOn f = go IntSet.empty
140+
-- GHC only inlines an INLINE function once it is applied to every argument
141+
-- left of the '='; the explicit lambda lets this inline when applied to f alone.
142+
nubIntOn f = \xs -> nubIntOnExcluding f IntSet.empty xs
143+
{-# INLINE nubIntOn #-}
144+
145+
-- Splitting nubIntOn like this means that we don't have to worry about
146+
-- matching specifically on IntSet.empty in the rewrite-back rule.
147+
nubIntOnExcluding :: (a -> Int) -> IntSet -> [a] -> [a]
148+
nubIntOnExcluding f = go
61149
where
62150
go _ [] = []
63-
go s (x:xs) = let fx = f x
64-
in if fx `IntSet.member` s then go s xs
65-
else x : go (IntSet.insert fx s) xs
151+
go s (x:xs)
152+
| fx `IntSet.member` s = go s xs
153+
| otherwise = x : go (IntSet.insert fx s) xs
154+
where !fx = f x
155+
156+
#ifdef __GLASGOW_HASKELL__
157+
-- We don't mark this INLINABLE because it doesn't seem obviously useful
158+
-- to inline it anywhere; the elements the function operates on are actually
159+
-- pulled from a list and installed in a list; the situation is very different
160+
-- when fusion occurs. In this case, we let GHC make the call.
161+
{-# NOINLINE [1] nubIntOnExcluding #-}
162+
163+
{-# RULES
164+
"nubIntOn" [~1] forall f as s. nubIntOnExcluding f s as =
165+
build (\c n -> foldr (nubIntOnFB f c) (constNubOn n) as s)
166+
"nubIntOnList" [1] forall f as s. foldr (nubIntOnFB f (:)) (constNubOn []) as s =
167+
nubIntOnExcluding f s as
168+
#-}
169+
170+
nubIntOnFB :: (a -> Int)
171+
-> (a -> r -> r)
172+
-> a
173+
-> (IntSet -> r)
174+
-> IntSet
175+
-> r
176+
nubIntOnFB f c x r s
177+
| fx `IntSet.member` s = r s
178+
| otherwise = x `c` r (IntSet.insert fx s)
179+
where !fx = f x
180+
{-# INLINABLE [0] nubIntOnFB #-}
181+
#endif

tests/listutils-properties.hs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,57 @@ import Data.List (nub, nubBy)
44
import Data.Containers.ListUtils
55
import Test.Framework
66
import Test.Framework.Providers.QuickCheck2
7+
import Test.QuickCheck (Property, (===))
8+
import Test.QuickCheck.Function (Fun, apply)
9+
import Test.QuickCheck.Poly (A, OrdA, B, OrdB, C)
710

811
main :: IO ()
912
main = defaultMain
10-
[ testProperty "nubOrd" prop_nubOrd
11-
, testProperty "nubOrdOn" prop_nubOrdOn
12-
, testProperty "nubInt" prop_nubInt
13-
, testProperty "nubIntOn" prop_nubIntOn
13+
[ testProperty "nubOrd" prop_nubOrd
14+
, testProperty "nubOrdOn" prop_nubOrdOn
15+
, testProperty "nubOrdOn fusion" prop_nubOrdOnFusion
16+
, testProperty "nubInt" prop_nubInt
17+
, testProperty "nubIntOn" prop_nubIntOn
18+
, testProperty "nubIntOn fusion" prop_nubIntOnFusion
1419
]
1520

1621

17-
prop_nubOrd :: [Int] -> Bool
18-
prop_nubOrd xs = nubOrd xs == nub xs
22+
prop_nubOrd :: [OrdA] -> Property
23+
prop_nubOrd xs = nubOrd xs === nub xs
1924

20-
prop_nubInt :: [Int] -> Bool
21-
prop_nubInt xs = nubInt xs == nub xs
25+
prop_nubInt :: [Int] -> Property
26+
prop_nubInt xs = nubInt xs === nub xs
2227

23-
prop_nubOrdOn :: [(Int,Int)] -> Bool
24-
prop_nubOrdOn xs = nubOrdOn snd xs == nubBy (\x y -> snd x == snd y) xs
28+
prop_nubOrdOn :: Fun A OrdB -> [A] -> Property
29+
prop_nubOrdOn f' xs =
30+
nubOrdOn f xs === nubBy (\x y -> f x == f y) xs
31+
where f = apply f'
2532

26-
prop_nubIntOn :: [(Int,Int)] -> Bool
27-
prop_nubIntOn xs = nubIntOn snd xs == nubBy (\x y -> snd x == snd y) xs
33+
prop_nubIntOn :: Fun A Int -> [A] -> Property
34+
prop_nubIntOn f' xs =
35+
nubIntOn f xs === nubBy (\x y -> f x == f y) xs
36+
where f = apply f'
37+
38+
prop_nubOrdOnFusion :: Fun B C
39+
-> Fun B OrdB
40+
-> Fun A B
41+
-> [A] -> Property
42+
prop_nubOrdOnFusion f' g' h' xs =
43+
(map f . nubOrdOn g . map h $ xs)
44+
=== (map f . nubBy (\x y -> g x == g y) . map h $ xs)
45+
where
46+
f = apply f'
47+
g = apply g'
48+
h = apply h'
49+
50+
prop_nubIntOnFusion :: Fun B C
51+
-> Fun B Int
52+
-> Fun A B
53+
-> [A] -> Property
54+
prop_nubIntOnFusion f' g' h' xs =
55+
(map f . nubIntOn g . map h $ xs)
56+
=== (map f . nubBy (\x y -> g x == g y) . map h $ xs)
57+
where
58+
f = apply f'
59+
g = apply g'
60+
h = apply h'

0 commit comments

Comments
 (0)