Skip to content

Commit eb60526

Browse files
authored
Merge pull request #498 from gksato/optimize-nextperm
Optimize Mutable.nextPermutation and add {next/prev}permutation(By)
2 parents fd76994 + 7415257 commit eb60526

File tree

11 files changed

+458
-51
lines changed

11 files changed

+458
-51
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
{-# LANGUAGE BangPatterns #-}
2+
{-# LANGUAGE FlexibleContexts #-}
3+
module Bench.Vector.Algo.NextPermutation (generatePermTests) where
4+
5+
import qualified Data.Vector.Unboxed as V
6+
import qualified Data.Vector.Unboxed.Mutable as M
7+
import qualified Data.Vector.Generic.Mutable as G
8+
import System.Random.Stateful
9+
( StatefulGen, UniformRange(uniformRM) )
10+
11+
-- | Generate a list of benchmarks for permutation algorithms.
12+
-- The list contains pairs of benchmark names and corresponding actions.
13+
-- The actions are to be executed by the benchmarking framework.
14+
--
15+
-- The list contains the following benchmarks:
16+
-- - @(next|prev)Permutation@ on a small vector repeated until the end of the permutation cycle
17+
-- - Bijective versions of @(next|prev)Permutation@ on a vector of size @n@, repeated @n@ times
18+
-- - ascending permutation
19+
-- - descending permutation
20+
-- - random permutation
21+
-- - Baseline for bijective versions: just copying a vector of size @n@. Note that the tests for
22+
-- bijective versions begins with copying a vector.
23+
generatePermTests :: StatefulGen g IO => g -> Int -> IO [(String, IO ())]
24+
generatePermTests gen useSize = do
25+
let !k = useSizeToPermLen useSize
26+
let !vasc = V.generate useSize id
27+
!vdesc = V.generate useSize (useSize-1-)
28+
!vrnd <- randomPermutationWith gen useSize
29+
return
30+
[ ("nextPermutation (small vector, until end)", loopPermutations k)
31+
, ("nextPermutationBijective (ascending perm of size n, n times)", repeatNextPermutation vasc useSize)
32+
, ("nextPermutationBijective (descending perm of size n, n times)", repeatNextPermutation vdesc useSize)
33+
, ("nextPermutationBijective (random perm of size n, n times)", repeatNextPermutation vrnd useSize)
34+
, ("prevPermutation (small vector, until end)", loopRevPermutations k)
35+
, ("prevPermutationBijective (ascending perm of size n, n times)", repeatPrevPermutation vasc useSize)
36+
, ("prevPermutationBijective (descending perm of size n, n times)", repeatPrevPermutation vdesc useSize)
37+
, ("prevPermutationBijective (random perm of size n, n times)", repeatPrevPermutation vrnd useSize)
38+
, ("baseline for *Bijective (just copying the vector of size n)", V.thaw vrnd >> return ())
39+
]
40+
41+
-- | Given a PRNG and a length @n@, generate a random permutation of @[0..n-1]@.
42+
randomPermutationWith :: (StatefulGen g IO) => g -> Int -> IO (V.Vector Int)
43+
randomPermutationWith gen n = do
44+
v <- M.generate n id
45+
V.forM_ (V.generate (n-1) id) $ \ !i -> do
46+
j <- uniformRM (i,n-1) gen
47+
M.swap v i j
48+
V.unsafeFreeze v
49+
50+
-- | Given @useSize@ benchmark option, compute the largest @n <= 12@ such that @n! <= useSize@.
51+
-- Repeat-nextPermutation-until-end benchmark will use @n@ as the length of the vector.
52+
-- Note that 12 is the largest @n@ such that @n!@ can be represented as an 'Int32'.
53+
useSizeToPermLen :: Int -> Int
54+
useSizeToPermLen us = case V.findIndex (> max 0 us) $ V.scanl' (*) 1 $ V.generate 12 (+1) of
55+
Just i -> i-1
56+
Nothing -> 12
57+
58+
-- | A bijective version of @G.nextPermutation@ that reverses the vector
59+
-- if it is already in descending order.
60+
-- "Bijective" here means that the function forms a cycle over all permutations
61+
-- of the vector's elements.
62+
--
63+
-- This has a nice property that should be benchmarked:
64+
-- this function takes amortized constant time each call,
65+
-- if successively called either Omega(n) times on a single vector having distinct elements,
66+
-- or arbitrary times on a single vector initially in strictly ascending order.
67+
nextPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
68+
nextPermutationBijective v = do
69+
res <- G.nextPermutation v
70+
if res then return True else G.reverse v >> return False
71+
72+
-- | A bijective version of @G.prevPermutation@ that reverses the vector
73+
-- if it is already in ascending order.
74+
-- "Bijective" here means that the function forms a cycle over all permutations
75+
-- of the vector's elements.
76+
--
77+
-- This has a nice property that should be benchmarked:
78+
-- this function takes amortized constant time each call,
79+
-- if successively called either Omega(n) times on a single vector having distinct elements,
80+
-- or arbitrary times on a single vector initially in strictly descending order.
81+
prevPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
82+
prevPermutationBijective v = do
83+
res <- G.prevPermutation v
84+
if res then return True else G.reverse v >> return False
85+
86+
-- | Repeat @nextPermutation@ on @[0..n-1]@ until the end.
87+
loopPermutations :: Int -> IO ()
88+
loopPermutations n = do
89+
v <- M.generate n id
90+
let loop = do
91+
res <- M.nextPermutation v
92+
if res then loop else return ()
93+
loop
94+
95+
-- | Repeat @prevPermutation@ on @[n-1,n-2..0]@ until the end.
96+
loopRevPermutations :: Int -> IO ()
97+
loopRevPermutations n = do
98+
v <- M.generate n (n-1-)
99+
let loop = do
100+
res <- M.prevPermutation v
101+
if res then loop else return ()
102+
loop
103+
104+
-- | Repeat @nextPermutationBijective@ on a given vector given times.
105+
repeatNextPermutation :: V.Vector Int -> Int -> IO ()
106+
repeatNextPermutation !v !n = do
107+
!mv <- V.thaw v
108+
let loop !i | i <= 0 = return ()
109+
loop !i = do
110+
_ <- nextPermutationBijective mv
111+
loop (i-1)
112+
loop n
113+
114+
-- | Repeat @prevPermutationBijective@ on a given vector given times.
115+
repeatPrevPermutation :: V.Vector Int -> Int -> IO ()
116+
repeatPrevPermutation !v !n = do
117+
!mv <- V.thaw v
118+
let loop !i | i <= 0 = return ()
119+
loop !i = do
120+
_ <- prevPermutationBijective mv
121+
loop (i-1)
122+
loop n

vector/benchmarks/Main.hs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
{-# LANGUAGE BangPatterns #-}
22
module Main where
33

4-
import Bench.Vector.Algo.MutableSet (mutableSet)
5-
import Bench.Vector.Algo.ListRank (listRank)
6-
import Bench.Vector.Algo.Rootfix (rootfix)
7-
import Bench.Vector.Algo.Leaffix (leaffix)
8-
import Bench.Vector.Algo.AwShCC (awshcc)
9-
import Bench.Vector.Algo.HybCC (hybcc)
10-
import Bench.Vector.Algo.Quickhull (quickhull)
11-
import Bench.Vector.Algo.Spectral (spectral)
12-
import Bench.Vector.Algo.Tridiag (tridiag)
13-
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
4+
import Bench.Vector.Algo.MutableSet (mutableSet)
5+
import Bench.Vector.Algo.ListRank (listRank)
6+
import Bench.Vector.Algo.Rootfix (rootfix)
7+
import Bench.Vector.Algo.Leaffix (leaffix)
8+
import Bench.Vector.Algo.AwShCC (awshcc)
9+
import Bench.Vector.Algo.HybCC (hybcc)
10+
import Bench.Vector.Algo.Quickhull (quickhull)
11+
import Bench.Vector.Algo.Spectral (spectral)
12+
import Bench.Vector.Algo.Tridiag (tridiag)
13+
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
14+
import Bench.Vector.Algo.NextPermutation (generatePermTests)
1415

1516
import Bench.Vector.TestData.ParenTree (parenTree)
1617
import Bench.Vector.TestData.Graph (randomGraph)
@@ -50,6 +51,7 @@ main = do
5051
!ds <- randomVector useSize
5152
!sp <- randomVector (floor $ sqrt $ fromIntegral useSize)
5253
vi <- MV.new useSize
54+
permTests <- generatePermTests gen useSize
5355

5456
defaultMainWithIngredients ingredients $ bgroup "All"
5557
[ bench "listRank" $ whnf listRank useSize
@@ -66,4 +68,5 @@ main = do
6668
, bench "findIndexR_manual" $ whnf findIndexR_manual ((<indexFindThreshold), as)
6769
, bench "minimumOn" $ whnf (U.minimumOn (\x -> x*x*x)) as
6870
, bench "maximumOn" $ whnf (U.maximumOn (\x -> x*x*x)) as
71+
, bgroup "(next|prev)Permutation" $ map (\(name, act) -> bench name $ whnfIO act) permTests
6972
]

vector/changelog.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# Changes in version 0.13.2.0
2+
3+
* We had some improvements on `*.Mutable.{next,prev}Permutation{,By}`
4+
[#498](https://github.com/haskell/vector/pull/498):
5+
* Add `*.Mutable.prevPermutation{,By}` and `*.Mutable.nextPermutationBy`
6+
* Improve time performance. We may now expect good specialization supported by inlining.
7+
The implementation has also been algorithmically updated: in the previous implementation
8+
the full enumeration of all the permutations of `[1..n]` took Omega(n*n!), but it now takes O(n!).
9+
* Add tests for `{next,prev}Permutation`
10+
* Add benchmarks for `{next,prev}Permutation`
11+
112
# Changes in version 0.13.1.0
213

314
* Specialized variants of `findIndexR` are reexported for all vector

vector/src/Data/Vector/Generic/Mutable.hs

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ module Data.Vector.Generic.Mutable (
5858
ifoldr, ifoldr', ifoldrM, ifoldrM',
5959

6060
-- * Modifying vectors
61-
nextPermutation,
61+
nextPermutation, nextPermutationBy,
62+
prevPermutation, prevPermutationBy,
6263

6364
-- ** Filling and copying
6465
set, copy, move, unsafeCopy, unsafeMove,
@@ -91,9 +92,10 @@ import Data.Vector.Internal.Check
9192
import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )
9293

9394
import Prelude
94-
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..)
95+
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..)
9596
, return, otherwise, flip, const, seq, min, max, not, pure
96-
, (>>=), (+), (-), (<), (<=), (>=), (==), (/=), (.), ($), (=<<), (>>), (<$>) )
97+
, (>>=), (+), (-), (<), (<=), (>), (>=), (==), (/=), (.), ($), (=<<), (>>), (<$>) )
98+
import Data.Bits ( Bits(shiftR) )
9799

98100
#include "vector.h"
99101

@@ -1213,6 +1215,47 @@ partitionWithUnknown f s
12131215
-- Modifying vectors
12141216
-- -----------------
12151217

1218+
1219+
-- | Compute the (lexicographically) next permutation of the given vector in-place.
1220+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
1221+
-- weakly descending order. In this case the vector will not get updated,
1222+
-- as opposed to the behavior of the C++ function @std::next_permutation@.
1223+
nextPermutation :: (PrimMonad m, Ord e, MVector v e) => v (PrimState m) e -> m Bool
1224+
{-# INLINE nextPermutation #-}
1225+
nextPermutation = nextPermutationByLt (<)
1226+
1227+
-- | Compute the (lexicographically) next permutation of the given vector in-place,
1228+
-- using the provided comparison function.
1229+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
1230+
-- weakly descending order. In this case the vector will not get updated,
1231+
-- as opposed to the behavior of the C++ function @std::next_permutation@.
1232+
--
1233+
-- @since 0.13.2.0
1234+
nextPermutationBy :: (PrimMonad m, MVector v e) => (e -> e -> Ordering) -> v (PrimState m) e -> m Bool
1235+
{-# INLINE nextPermutationBy #-}
1236+
nextPermutationBy cmp = nextPermutationByLt (\x y -> cmp x y == LT)
1237+
1238+
-- | Compute the (lexicographically) previous permutation of the given vector in-place.
1239+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
1240+
-- weakly ascending order. In this case the vector will not get updated,
1241+
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
1242+
--
1243+
-- @since 0.13.2.0
1244+
prevPermutation :: (PrimMonad m, Ord e, MVector v e) => v (PrimState m) e -> m Bool
1245+
{-# INLINE prevPermutation #-}
1246+
prevPermutation = nextPermutationByLt (>)
1247+
1248+
-- | Compute the (lexicographically) previous permutation of the given vector in-place,
1249+
-- using the provided comparison function.
1250+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
1251+
-- weakly ascending order. In this case the vector will not get updated,
1252+
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
1253+
--
1254+
-- @since 0.13.2.0
1255+
prevPermutationBy :: (PrimMonad m, MVector v e) => (e -> e -> Ordering) -> v (PrimState m) e -> m Bool
1256+
{-# INLINE prevPermutationBy #-}
1257+
prevPermutationBy cmp = nextPermutationByLt (\x y -> cmp x y == GT)
1258+
12161259
{-
12171260
http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
12181261
@@ -1224,30 +1267,51 @@ a given permutation. It changes the given permutation in-place.
12241267
2. Find the largest index l greater than k such that a[k] < a[l].
12251268
3. Swap the value of a[k] with that of a[l].
12261269
4. Reverse the sequence from a[k + 1] up to and including the final element a[n]
1270+
1271+
The algorithm has been updated to look up the k in Step 1 beginning from the
1272+
last of the vector; which renders the algorithm to achieve the average time
1273+
complexity of O(1) each call. The worst case time complexity is still O(n).
1274+
The orginal implementation, which scanned the vector from the left, had the
1275+
time complexity of O(n) on the best case.
12271276
-}
12281277

12291278
-- | Compute the (lexicographically) next permutation of the given vector in-place.
1230-
-- Returns False when the input is the last permutation.
1231-
nextPermutation :: (PrimMonad m,Ord e,MVector v e) => v (PrimState m) e -> m Bool
1232-
nextPermutation v
1233-
| dim < 2 = return False
1234-
| otherwise = do
1235-
val <- unsafeRead v 0
1236-
(k,l) <- loop val (-1) 0 val 1
1237-
if k < 0
1238-
then return False
1239-
else unsafeSwap v k l >>
1240-
reverse (unsafeSlice (k+1) (dim-k-1) v) >>
1241-
return True
1242-
where loop !kval !k !l !prev !i
1243-
| i == dim = return (k,l)
1244-
| otherwise = do
1245-
cur <- unsafeRead v i
1246-
-- TODO: make tuple unboxed
1247-
let (kval',k') = if prev < cur then (prev,i-1) else (kval,k)
1248-
l' = if kval' < cur then i else l
1249-
loop kval' k' l' cur (i+1)
1250-
dim = length v
1279+
-- Here, the first argument should be a less-than comparison function.
1280+
-- Returns False when the input is the last permutation; in this case the vector
1281+
-- will not get updated, as opposed to the behavior of the C++ function
1282+
-- @std::next_permutation@.
1283+
nextPermutationByLt :: (PrimMonad m, MVector v e) => (e -> e -> Bool) -> v (PrimState m) e -> m Bool
1284+
{-# INLINE nextPermutationByLt #-}
1285+
nextPermutationByLt lt v
1286+
| dim < 2 = return False
1287+
| otherwise = stToPrim $ do
1288+
!vlast <- unsafeRead v (dim - 1)
1289+
decrLoop (dim - 2) vlast
1290+
where
1291+
dim = length v
1292+
-- find the largest index k such that a[k] < a[k + 1], and then pass to the rest.
1293+
decrLoop !i !vi1 | i >= 0 = do
1294+
!vi <- unsafeRead v i
1295+
if vi `lt` vi1 then swapLoop i vi (i+1) vi1 dim else decrLoop (i-1) vi
1296+
decrLoop _ !_ = return False
1297+
-- find the largest index l greater than k such that a[k] < a[l], and do the rest.
1298+
swapLoop !k !vk = go
1299+
where
1300+
-- binary search.
1301+
go !l !vl !r | r - l <= 1 = do
1302+
-- Done; do the rest of the algorithm.
1303+
unsafeWrite v k vl
1304+
unsafeWrite v l vk
1305+
reverse $ unsafeSlice (k + 1) (dim - k - 1) v
1306+
return True
1307+
go !l !vl !r = do
1308+
!vmid <- unsafeRead v mid
1309+
if vk `lt` vmid
1310+
then go mid vmid r
1311+
else go l vl mid
1312+
where
1313+
!mid = l + (r - l) `shiftR` 1
1314+
12511315

12521316
-- $setup
12531317
-- >>> import Prelude ((*))

vector/src/Data/Vector/Mutable.hs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ module Data.Vector.Mutable (
5858
ifoldr, ifoldr', ifoldrM, ifoldrM',
5959

6060
-- * Modifying vectors
61-
nextPermutation,
61+
nextPermutation, nextPermutationBy,
62+
prevPermutation, prevPermutationBy,
6263

6364
-- ** Filling and copying
6465
set, copy, move, unsafeCopy, unsafeMove,
@@ -574,11 +575,45 @@ unsafeMove = G.unsafeMove
574575
-- -----------------
575576

576577
-- | Compute the (lexicographically) next permutation of the given vector in-place.
577-
-- Returns False when the input is the last permutation.
578+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
579+
-- weakly descending order. In this case the vector will not get updated,
580+
-- as opposed to the behavior of the C++ function @std::next_permutation@.
578581
nextPermutation :: (PrimMonad m, Ord e) => MVector (PrimState m) e -> m Bool
579582
{-# INLINE nextPermutation #-}
580583
nextPermutation = G.nextPermutation
581584

585+
-- | Compute the (lexicographically) next permutation of the given vector in-place,
586+
-- using the provided comparison function.
587+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
588+
-- weakly descending order. In this case the vector will not get updated,
589+
-- as opposed to the behavior of the C++ function @std::next_permutation@.
590+
--
591+
-- @since 0.13.2.0
592+
nextPermutationBy :: PrimMonad m => (e -> e -> Ordering) -> MVector (PrimState m) e -> m Bool
593+
{-# INLINE nextPermutationBy #-}
594+
nextPermutationBy = G.nextPermutationBy
595+
596+
-- | Compute the (lexicographically) previous permutation of the given vector in-place.
597+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
598+
-- weakly ascending order. In this case the vector will not get updated,
599+
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
600+
--
601+
-- @since 0.13.2.0
602+
prevPermutation :: (PrimMonad m, Ord e) => MVector (PrimState m) e -> m Bool
603+
{-# INLINE prevPermutation #-}
604+
prevPermutation = G.prevPermutation
605+
606+
-- | Compute the (lexicographically) previous permutation of the given vector in-place,
607+
-- using the provided comparison function.
608+
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
609+
-- weakly ascending order. In this case the vector will not get updated,
610+
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
611+
--
612+
-- @since 0.13.2.0
613+
prevPermutationBy :: PrimMonad m => (e -> e -> Ordering) -> MVector (PrimState m) e -> m Bool
614+
{-# INLINE prevPermutationBy #-}
615+
prevPermutationBy = G.prevPermutationBy
616+
582617
-- Folds
583618
-- -----
584619

0 commit comments

Comments
 (0)