@@ -58,7 +58,8 @@ module Data.Vector.Generic.Mutable (
58
58
ifoldr , ifoldr' , ifoldrM , ifoldrM' ,
59
59
60
60
-- * Modifying vectors
61
- nextPermutation ,
61
+ nextPermutation , nextPermutationBy ,
62
+ prevPermutation , prevPermutationBy ,
62
63
63
64
-- ** Filling and copying
64
65
set , copy , move , unsafeCopy , unsafeMove ,
@@ -91,9 +92,10 @@ import Data.Vector.Internal.Check
91
92
import Control.Monad.Primitive ( PrimMonad (.. ), RealWorld , stToPrim )
92
93
93
94
import Prelude
94
- ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. )
95
+ ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. ), Ordering ( .. )
95
96
, return , otherwise , flip , const , seq , min , max , not , pure
96
- , (>>=) , (+) , (-) , (<) , (<=) , (>=) , (==) , (/=) , (.) , ($) , (=<<) , (>>) , (<$>) )
97
+ , (>>=) , (+) , (-) , (<) , (<=) , (>) , (>=) , (==) , (/=) , (.) , ($) , (=<<) , (>>) , (<$>) )
98
+ import Data.Bits ( Bits (shiftR ) )
97
99
98
100
#include "vector.h"
99
101
@@ -1213,6 +1215,47 @@ partitionWithUnknown f s
1213
1215
-- Modifying vectors
1214
1216
-- -----------------
1215
1217
1218
+
1219
+ -- | Compute the (lexicographically) next permutation of the given vector in-place.
1220
+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1221
+ -- weakly descending order. In this case the vector will not get updated,
1222
+ -- as opposed to the behavior of the C++ function @std::next_permutation@.
1223
+ nextPermutation :: (PrimMonad m , Ord e , MVector v e ) => v (PrimState m ) e -> m Bool
1224
+ {-# INLINE nextPermutation #-}
1225
+ nextPermutation = nextPermutationByLt (<)
1226
+
1227
+ -- | Compute the (lexicographically) next permutation of the given vector in-place,
1228
+ -- using the provided comparison function.
1229
+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1230
+ -- weakly descending order. In this case the vector will not get updated,
1231
+ -- as opposed to the behavior of the C++ function @std::next_permutation@.
1232
+ --
1233
+ -- @since 0.13.2.0
1234
+ nextPermutationBy :: (PrimMonad m , MVector v e ) => (e -> e -> Ordering ) -> v (PrimState m ) e -> m Bool
1235
+ {-# INLINE nextPermutationBy #-}
1236
+ nextPermutationBy cmp = nextPermutationByLt (\ x y -> cmp x y == LT )
1237
+
1238
+ -- | Compute the (lexicographically) previous permutation of the given vector in-place.
1239
+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1240
+ -- weakly ascending order. In this case the vector will not get updated,
1241
+ -- as opposed to the behavior of the C++ function @std::prev_permutation@.
1242
+ --
1243
+ -- @since 0.13.2.0
1244
+ prevPermutation :: (PrimMonad m , Ord e , MVector v e ) => v (PrimState m ) e -> m Bool
1245
+ {-# INLINE prevPermutation #-}
1246
+ prevPermutation = nextPermutationByLt (>)
1247
+
1248
+ -- | Compute the (lexicographically) previous permutation of the given vector in-place,
1249
+ -- using the provided comparison function.
1250
+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1251
+ -- weakly ascending order. In this case the vector will not get updated,
1252
+ -- as opposed to the behavior of the C++ function @std::prev_permutation@.
1253
+ --
1254
+ -- @since 0.13.2.0
1255
+ prevPermutationBy :: (PrimMonad m , MVector v e ) => (e -> e -> Ordering ) -> v (PrimState m ) e -> m Bool
1256
+ {-# INLINE prevPermutationBy #-}
1257
+ prevPermutationBy cmp = nextPermutationByLt (\ x y -> cmp x y == GT )
1258
+
1216
1259
{-
1217
1260
http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
1218
1261
@@ -1224,30 +1267,51 @@ a given permutation. It changes the given permutation in-place.
1224
1267
2. Find the largest index l greater than k such that a[k] < a[l].
1225
1268
3. Swap the value of a[k] with that of a[l].
1226
1269
4. Reverse the sequence from a[k + 1] up to and including the final element a[n]
1270
+
1271
+ The algorithm has been updated to look up the k in Step 1 beginning from the
1272
+ last of the vector; which renders the algorithm to achieve the average time
1273
+ complexity of O(1) each call. The worst case time complexity is still O(n).
1274
+ The orginal implementation, which scanned the vector from the left, had the
1275
+ time complexity of O(n) on the best case.
1227
1276
-}
1228
1277
1229
1278
-- | Compute the (lexicographically) next permutation of the given vector in-place.
1230
- -- Returns False when the input is the last permutation.
1231
- nextPermutation :: (PrimMonad m ,Ord e ,MVector v e ) => v (PrimState m ) e -> m Bool
1232
- nextPermutation v
1233
- | dim < 2 = return False
1234
- | otherwise = do
1235
- val <- unsafeRead v 0
1236
- (k,l) <- loop val (- 1 ) 0 val 1
1237
- if k < 0
1238
- then return False
1239
- else unsafeSwap v k l >>
1240
- reverse (unsafeSlice (k+ 1 ) (dim- k- 1 ) v) >>
1241
- return True
1242
- where loop ! kval ! k ! l ! prev ! i
1243
- | i == dim = return (k,l)
1244
- | otherwise = do
1245
- cur <- unsafeRead v i
1246
- -- TODO: make tuple unboxed
1247
- let (kval',k') = if prev < cur then (prev,i- 1 ) else (kval,k)
1248
- l' = if kval' < cur then i else l
1249
- loop kval' k' l' cur (i+ 1 )
1250
- dim = length v
1279
+ -- Here, the first argument should be a less-than comparison function.
1280
+ -- Returns False when the input is the last permutation; in this case the vector
1281
+ -- will not get updated, as opposed to the behavior of the C++ function
1282
+ -- @std::next_permutation@.
1283
+ nextPermutationByLt :: (PrimMonad m , MVector v e ) => (e -> e -> Bool ) -> v (PrimState m ) e -> m Bool
1284
+ {-# INLINE nextPermutationByLt #-}
1285
+ nextPermutationByLt lt v
1286
+ | dim < 2 = return False
1287
+ | otherwise = stToPrim $ do
1288
+ ! vlast <- unsafeRead v (dim - 1 )
1289
+ decrLoop (dim - 2 ) vlast
1290
+ where
1291
+ dim = length v
1292
+ -- find the largest index k such that a[k] < a[k + 1], and then pass to the rest.
1293
+ decrLoop ! i ! vi1 | i >= 0 = do
1294
+ ! vi <- unsafeRead v i
1295
+ if vi `lt` vi1 then swapLoop i vi (i+ 1 ) vi1 dim else decrLoop (i- 1 ) vi
1296
+ decrLoop _ ! _ = return False
1297
+ -- find the largest index l greater than k such that a[k] < a[l], and do the rest.
1298
+ swapLoop ! k ! vk = go
1299
+ where
1300
+ -- binary search.
1301
+ go ! l ! vl ! r | r - l <= 1 = do
1302
+ -- Done; do the rest of the algorithm.
1303
+ unsafeWrite v k vl
1304
+ unsafeWrite v l vk
1305
+ reverse $ unsafeSlice (k + 1 ) (dim - k - 1 ) v
1306
+ return True
1307
+ go ! l ! vl ! r = do
1308
+ ! vmid <- unsafeRead v mid
1309
+ if vk `lt` vmid
1310
+ then go mid vmid r
1311
+ else go l vl mid
1312
+ where
1313
+ ! mid = l + (r - l) `shiftR` 1
1314
+
1251
1315
1252
1316
-- $setup
1253
1317
-- >>> import Prelude ((*))
0 commit comments