|
| 1 | +-- | |
| 2 | +-- Module : Streamly.Benchmark.Data.Scanl |
| 3 | +-- Copyright : (c) 2018 Composewell |
| 4 | +-- |
| 5 | +-- License : MIT |
| 6 | +-- Maintainer : streamly@composewell.com |
| 7 | + |
| 8 | +{-# LANGUAGE CPP #-} |
| 9 | +{-# LANGUAGE FlexibleContexts #-} |
| 10 | +{-# LANGUAGE FlexibleInstances #-} |
| 11 | +{-# LANGUAGE RankNTypes #-} |
| 12 | +{-# LANGUAGE ScopedTypeVariables #-} |
| 13 | +{-# OPTIONS_GHC -Wno-orphans #-} |
| 14 | + |
| 15 | +module Main (main) where |
| 16 | + |
| 17 | +import Data.IORef (newIORef, readIORef, modifyIORef) |
| 18 | +import Control.DeepSeq (NFData(..)) |
| 19 | +import Data.Functor.Identity (Identity(..)) |
| 20 | +import System.Random (randomRIO) |
| 21 | + |
| 22 | +import Streamly.Internal.Data.Stream (Stream) |
| 23 | +import Streamly.Internal.Data.Scanl (Scanl(..)) |
| 24 | +import Streamly.Internal.Data.MutArray (MutArray) |
| 25 | + |
| 26 | +import qualified Streamly.Internal.Data.Fold as FL |
| 27 | +import qualified Streamly.Internal.Data.Scanl as Scanl |
| 28 | +import qualified Streamly.Internal.Data.Stream as Stream |
| 29 | + |
| 30 | +import Test.Tasty.Bench hiding (env) |
| 31 | +import Streamly.Benchmark.Common |
| 32 | +import Prelude hiding (last, length, all, any, take, unzip, sequence_, filter) |
| 33 | + |
| 34 | +import qualified Data.Set as Set |
| 35 | +import System.IO.Unsafe (unsafePerformIO) |
| 36 | + |
| 37 | +------------------------------------------------------------------------------- |
| 38 | +-- Helpers |
| 39 | +------------------------------------------------------------------------------- |
| 40 | + |
| 41 | +{-# INLINE source #-} |
| 42 | +source :: (Monad m, Num a, Stream.Enumerable a) => |
| 43 | + Int -> a -> Stream m a |
| 44 | +source len from = |
| 45 | + Stream.enumerateFromThenTo from (from + 1) (from + fromIntegral len) |
| 46 | + |
| 47 | +{-# INLINE benchScanWith #-} |
| 48 | +benchScanWith :: Num a => |
| 49 | + (Int -> a -> Stream IO a) -> Int -> String -> Scanl IO a b -> Benchmark |
| 50 | +benchScanWith src len name f = |
| 51 | + bench name |
| 52 | + $ nfIO |
| 53 | + $ randomRIO (1, 1 :: Int) |
| 54 | + >>= Stream.fold FL.drain |
| 55 | + . Stream.postscanl f |
| 56 | + . src len |
| 57 | + . fromIntegral |
| 58 | + |
| 59 | +{-# INLINE benchWithPostscan #-} |
| 60 | +benchWithPostscan :: Int -> String -> Scanl IO Int a -> Benchmark |
| 61 | +benchWithPostscan = benchScanWith source |
| 62 | + |
| 63 | +------------------------------------------------------------------------------- |
| 64 | +-- Benchmarks |
| 65 | +------------------------------------------------------------------------------- |
| 66 | + |
| 67 | +moduleName :: String |
| 68 | +moduleName = "Data.Scanl" |
| 69 | + |
| 70 | +instance NFData (MutArray a) where |
| 71 | + {-# INLINE rnf #-} |
| 72 | + rnf _ = () |
| 73 | + |
| 74 | +instance NFData a => NFData (Stream Identity a) where |
| 75 | + {-# INLINE rnf #-} |
| 76 | + rnf xs = runIdentity $ Stream.fold (FL.foldl' (\_ x -> rnf x) ()) xs |
| 77 | + |
| 78 | +o_n_heap_serial :: Int -> [Benchmark] |
| 79 | +o_n_heap_serial value = |
| 80 | + [ bgroup "key-value" |
| 81 | + [ |
| 82 | + benchWithPostscan value "demuxIO (1-shot) (64 buckets) [sum 100]" |
| 83 | + $ Scanl.demuxIO (getKey 64) getScanl |
| 84 | + , benchWithPostscan value "demuxIO (64 buckets) [sum]" |
| 85 | + $ Scanl.demuxIO (getKey 64) (const (pure (Just Scanl.sum))) |
| 86 | + , benchWithPostscan value "classifyIO (64 buckets) [sum 100]" |
| 87 | + $ Scanl.classifyIO (getKey 64) (limitedSum 100) |
| 88 | + , benchWithPostscan value "classifyIO (64 buckets) [sum]" |
| 89 | + $ Scanl.classifyIO (getKey 64) Scanl.sum |
| 90 | + ] |
| 91 | + ] |
| 92 | + |
| 93 | + where |
| 94 | + |
| 95 | + limitedSum n = Scanl.take n Scanl.sum |
| 96 | + |
| 97 | + getKey buckets = (`mod` buckets) |
| 98 | + |
| 99 | + afterDone action (Scanl.Scanl step i e f) = Scanl.Scanl step1 i e f |
| 100 | + where |
| 101 | + step1 x a = do |
| 102 | + res <- step x a |
| 103 | + case res of |
| 104 | + Scanl.Partial s1 -> pure $ Scanl.Partial s1 |
| 105 | + Scanl.Done b -> action >> pure (Scanl.Done b) |
| 106 | + |
| 107 | + ref = unsafePerformIO $ newIORef Set.empty |
| 108 | + getScanl k = do |
| 109 | + set <- readIORef ref |
| 110 | + if Set.member k set |
| 111 | + then pure Nothing |
| 112 | + else pure |
| 113 | + $ Just |
| 114 | + $ afterDone (modifyIORef ref (Set.insert k)) (limitedSum 100) |
| 115 | + |
| 116 | +------------------------------------------------------------------------------- |
| 117 | +-- Driver |
| 118 | +------------------------------------------------------------------------------- |
| 119 | + |
| 120 | +main :: IO () |
| 121 | +main = runWithCLIOpts defaultStreamSize allBenchmarks |
| 122 | + |
| 123 | + where |
| 124 | + |
| 125 | + allBenchmarks value = |
| 126 | + [ bgroup (o_n_heap_prefix moduleName) (o_n_heap_serial value) |
| 127 | + ] |
0 commit comments