diff --git a/Numeric/Sum.hs b/Numeric/Sum.hs index f73ebe5..73d52ec 100644 --- a/Numeric/Sum.hs +++ b/Numeric/Sum.hs @@ -1,5 +1,6 @@ {-# LANGUAGE BangPatterns, DeriveDataTypeable, FlexibleContexts, - MultiParamTypeClasses, TypeFamilies, CPP #-} + FlexibleInstances, MultiParamTypeClasses, ScopedTypeVariables, + TypeFamilies, CPP #-} {-# OPTIONS_GHC -fno-warn-name-shadowing #-} -- | -- Module : Numeric.Sum @@ -69,25 +70,21 @@ import qualified Data.Vector.Generic.Mutable as GM import qualified Data.Vector.Unboxed as U -- | A class for summation of floating point numbers. -class Summation s where +class RealFloat a => Summation s a where -- | The identity for summation. - zero :: s + zero :: s a -- | Add a value to a sum. - add :: s -> Double -> s + add :: s a -> a -> s a -- | Sum a collection of values. -- -- Example: -- @foo = 'Numeric.Sum.sum' 'kbn' [1,2,3]@ - sum :: (F.Foldable f) => (s -> Double) -> f Double -> Double - sum f = f . F.foldl' add zero + sum :: F.Foldable f => (s a -> a) -> f a -> a + sum f = f . F.foldl' add (zero :: s a) {-# INLINE sum #-} -instance Summation Double where - zero = 0 - add = (+) - -- | Kahan summation. This is the least accurate of the compensated -- summation methods. In practice, it only beats naive summation for -- inputs with large magnitude. Kahan summation can be /less/ @@ -96,12 +93,12 @@ instance Summation Double where -- This summation method is included for completeness. Its use is not -- recommended. In practice, 'KBNSum' is both 30% faster and more -- accurate. -data KahanSum = KahanSum {-# UNPACK #-} !Double {-# UNPACK #-} !Double +data KahanSum a = KahanSum !a !a deriving (Eq, Show, Typeable, Data) -instance U.Unbox KahanSum -newtype instance U.MVector s KahanSum = MV_KahanSum (U.MVector s (Double, Double)) -instance MVector U.MVector KahanSum where +instance U.Unbox a => U.Unbox (KahanSum a) +newtype instance U.MVector s (KahanSum a) = MV_KahanSum (U.MVector s (a, a)) +instance U.Unbox a => MVector U.MVector (KahanSum a) where {-# INLINE GM.basicLength #-} {-# INLINE GM.basicUnsafeSlice #-} {-# INLINE basicOverlaps #-} @@ -129,8 +126,8 @@ instance MVector U.MVector KahanSum where basicUnsafeMove (MV_KahanSum mvec) (MV_KahanSum mvec') = basicUnsafeMove mvec mvec' basicUnsafeGrow (MV_KahanSum mvec) len = MV_KahanSum `liftM` basicUnsafeGrow mvec len -newtype instance U.Vector KahanSum = V_KahanSum (U.Vector (Double, Double)) -instance Vector U.Vector KahanSum where +newtype instance U.Vector (KahanSum a) = V_KahanSum (U.Vector (a, a)) +instance U.Unbox a => Vector U.Vector (KahanSum a) where {-# INLINE basicUnsafeFreeze #-} {-# INLINE basicUnsafeThaw #-} {-# INLINE G.basicLength #-} @@ -147,43 +144,43 @@ instance Vector U.Vector KahanSum where elemseq (V_KahanSum vec) val = elemseq vec ((\ (KahanSum a b) -> (a, b)) val) -instance Summation KahanSum where +instance RealFloat a => Summation KahanSum a where zero = KahanSum 0 0 add = kahanAdd -instance NFData KahanSum where +instance NFData (KahanSum a) where rnf !_ = () -- | @since 0.3.0.0 -instance Monoid KahanSum where +instance RealFloat a => Monoid (KahanSum a) where mempty = zero s `mappend` KahanSum s' _ = add s s' #if MIN_VERSION_base(4,9,0) -- | @since 0.3.0.0 -instance Semigroup KahanSum where +instance RealFloat a => Semigroup (KahanSum a) where (<>) = mappend #endif -kahanAdd :: KahanSum -> Double -> KahanSum +kahanAdd :: RealFloat a => KahanSum a -> a -> KahanSum a kahanAdd (KahanSum sum c) x = KahanSum sum' c' where sum' = sum + y c' = (sum' - sum) - y y = x - c -- | Return the result of a Kahan sum. -kahan :: KahanSum -> Double +kahan :: KahanSum a -> a kahan (KahanSum sum _) = sum -- | Kahan-Babuška-Neumaier summation. This is a little more -- computationally costly than plain Kahan summation, but is /always/ -- at least as accurate. -data KBNSum = KBNSum {-# UNPACK #-} !Double {-# UNPACK #-} !Double +data KBNSum a = KBNSum !a !a deriving (Eq, Show, Typeable, Data) -instance U.Unbox KBNSum -newtype instance U.MVector s KBNSum = MV_KBNSum (U.MVector s (Double, Double)) -instance MVector U.MVector KBNSum where +instance U.Unbox a => U.Unbox (KBNSum a) +newtype instance U.MVector s (KBNSum a) = MV_KBNSum (U.MVector s (a, a)) +instance U.Unbox a => MVector U.MVector (KBNSum a) where {-# INLINE GM.basicLength #-} {-# INLINE GM.basicUnsafeSlice #-} {-# INLINE basicOverlaps #-} @@ -211,8 +208,8 @@ instance MVector U.MVector KBNSum where basicUnsafeMove (MV_KBNSum mvec) (MV_KBNSum mvec') = basicUnsafeMove mvec mvec' basicUnsafeGrow (MV_KBNSum mvec) len = MV_KBNSum `liftM` basicUnsafeGrow mvec len -newtype instance U.Vector KBNSum = V_KBNSum (U.Vector (Double, Double)) -instance Vector U.Vector KBNSum where +newtype instance U.Vector (KBNSum a) = V_KBNSum (U.Vector (a, a)) +instance U.Unbox a => Vector U.Vector (KBNSum a) where {-# INLINE basicUnsafeFreeze #-} {-# INLINE basicUnsafeThaw #-} {-# INLINE G.basicLength #-} @@ -229,32 +226,32 @@ instance Vector U.Vector KBNSum where elemseq (V_KBNSum vec) val = elemseq vec ((\ (KBNSum a b) -> (a, b)) val) -instance Summation KBNSum where +instance RealFloat a => Summation KBNSum a where zero = KBNSum 0 0 add = kbnAdd -instance NFData KBNSum where +instance NFData (KBNSum a) where rnf !_ = () -- | @since 0.3.0.0 -instance Monoid KBNSum where +instance RealFloat a => Monoid (KBNSum a) where mempty = zero s `mappend` KBNSum s' c' = add (add s s') c' #if MIN_VERSION_base(4,9,0) -- | @since 0.3.0.0 -instance Semigroup KBNSum where +instance RealFloat a => Semigroup (KBNSum a) where (<>) = mappend #endif -kbnAdd :: KBNSum -> Double -> KBNSum +kbnAdd :: (Num a, Ord a) => KBNSum a -> a -> KBNSum a kbnAdd (KBNSum sum c) x = KBNSum sum' c' where c' | abs sum >= abs x = c + ((sum - sum') + x) | otherwise = c + ((x - sum') + sum) sum' = sum + x -- | Return the result of a Kahan-Babuška-Neumaier sum. -kbn :: KBNSum -> Double +kbn :: Num a => KBNSum a -> a kbn (KBNSum sum c) = sum + c -- | Second-order Kahan-Babuška summation. This is more @@ -265,14 +262,12 @@ kbn (KBNSum sum c) = sum + c -- This method compensates for error in both the sum and the -- first-order compensation term, hence the use of \"second order\" in -- the name. -data KB2Sum = KB2Sum {-# UNPACK #-} !Double - {-# UNPACK #-} !Double - {-# UNPACK #-} !Double +data KB2Sum a = KB2Sum !a !a !a deriving (Eq, Show, Typeable, Data) -instance U.Unbox KB2Sum -newtype instance U.MVector s KB2Sum = MV_KB2Sum (U.MVector s (Double, Double, Double)) -instance MVector U.MVector KB2Sum where +instance U.Unbox a => U.Unbox (KB2Sum a) +newtype instance U.MVector s (KB2Sum a) = MV_KB2Sum (U.MVector s (a, a, a)) +instance U.Unbox a => MVector U.MVector (KB2Sum a) where {-# INLINE GM.basicLength #-} {-# INLINE GM.basicUnsafeSlice #-} {-# INLINE basicOverlaps #-} @@ -300,8 +295,8 @@ instance MVector U.MVector KB2Sum where basicUnsafeMove (MV_KB2Sum mvec) (MV_KB2Sum mvec') = basicUnsafeMove mvec mvec' basicUnsafeGrow (MV_KB2Sum mvec) len = MV_KB2Sum `liftM` basicUnsafeGrow mvec len -newtype instance U.Vector KB2Sum = V_KB2Sum (U.Vector (Double, Double, Double)) -instance Vector U.Vector KB2Sum where +newtype instance U.Vector (KB2Sum a) = V_KB2Sum (U.Vector (a, a, a)) +instance U.Unbox a => Vector U.Vector (KB2Sum a) where {-# INLINE basicUnsafeFreeze #-} {-# INLINE basicUnsafeThaw #-} {-# INLINE G.basicLength #-} @@ -317,26 +312,26 @@ instance Vector U.Vector KB2Sum where basicUnsafeCopy (MV_KB2Sum mvec) (V_KB2Sum vec) = G.basicUnsafeCopy mvec vec elemseq (V_KB2Sum vec) val = elemseq vec ((\ (KB2Sum a b c) -> (a, b, c)) val) -instance Summation KB2Sum where +instance RealFloat a => Summation KB2Sum a where zero = KB2Sum 0 0 0 add = kb2Add -instance NFData KB2Sum where +instance NFData (KB2Sum a) where rnf !_ = () -- | @since 0.3.0.0 -instance Monoid KB2Sum where +instance RealFloat a => Monoid (KB2Sum a) where mempty = zero s `mappend` KB2Sum s' c' cc' = add (add (add s s') c') cc' #if MIN_VERSION_base(4,9,0) -- | @since 0.3.0.0 -instance Semigroup KB2Sum where +instance RealFloat a => Semigroup (KB2Sum a) where (<>) = mappend #endif -kb2Add :: KB2Sum -> Double -> KB2Sum +kb2Add :: (Num a, Ord a) => KB2Sum a -> a -> KB2Sum a kb2Add (KB2Sum sum c cc) x = KB2Sum sum' c' cc' where sum' = sum + x c' = c + k @@ -346,12 +341,11 @@ kb2Add (KB2Sum sum c cc) x = KB2Sum sum' c' cc' | otherwise = (x - sum') + sum -- | Return the result of an order-2 Kahan-Babuška sum. -kb2 :: KB2Sum -> Double +kb2 :: Num a => KB2Sum a -> a kb2 (KB2Sum sum c cc) = sum + c + cc -- | /O(n)/ Sum a vector of values. -sumVector :: (Vector v Double, Summation s) => - (s -> Double) -> v Double -> Double +sumVector :: RealFloat a => (Vector v a, Summation s a) => (s a -> a) -> v a -> a sumVector f = f . foldl' add zero {-# INLINE sumVector #-} @@ -361,7 +355,7 @@ sumVector f = f . foldl' add zero -- bounds on its error growth. Instead of having roughly constant -- error regardless of the size of the input vector, in the worst case -- its accumulated error grows with /O(log n)/. -pairwiseSum :: (Vector v Double) => v Double -> Double +pairwiseSum :: RealFloat a => (Vector v a) => v a -> a pairwiseSum v | len <= 256 = G.sum v | otherwise = uncurry (+) . (pairwiseSum *** pairwiseSum) . @@ -383,7 +377,7 @@ pairwiseSum v -- computes the sum of elements in a list. -- -- @ --- sillySumList :: [Double] -> Double +-- sillySumList :: RealFloat a => [a] -> a -- sillySumList = loop 'zero' -- where loop s [] = 'kbn' s -- loop s (x:xs) = 'seq' s' loop s' xs @@ -397,7 +391,7 @@ pairwiseSum v -- -- Avoid ambiguity around which sum function we are using. -- import Prelude hiding (sum) -- -- --- betterSumList :: [Double] -> Double +-- betterSumList :: RealFloat a => [a] -> a -- betterSumList xs = 'Numeric.Sum.sum' 'kbn' xs -- @ @@ -410,7 +404,7 @@ pairwiseSum v -- intermediate values are as accurate as possible. -- -- @ --- prefixSum :: [Double] -> [Double] +-- prefixSum :: RealFloat a => [a] -> [a] -- prefixSum xs = map 'kbn' . 'scanl' 'add' 'zero' $ xs -- @ diff --git a/tests/Tests/Sum.hs b/tests/Tests/Sum.hs index 08eaf1e..c347b0a 100644 --- a/tests/Tests/Sum.hs +++ b/tests/Tests/Sum.hs @@ -1,8 +1,13 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Tests.Sum (tests) where import Control.Applicative ((<$>)) +import Foreign.C.Types import Numeric.Sum as Sum import Prelude hiding (sum) import Test.Tasty (TestTree, testGroup) @@ -10,20 +15,20 @@ import Test.Tasty.QuickCheck (testProperty) import Test.QuickCheck (Arbitrary(..)) import qualified Prelude -t_sum :: ([Double] -> Double) -> [Double] -> Bool +t_sum :: RealFloat a => ([a] -> a) -> [a] -> Bool t_sum f xs = f xs == trueSum xs -t_sum_error :: ([Double] -> Double) -> [Double] -> Bool +t_sum_error :: RealFloat a => ([a] -> a) -> [a] -> Bool t_sum_error f xs = abs (ts - f xs) <= abs (ts - Prelude.sum xs) where ts = trueSum xs -t_sum_shifted :: ([Double] -> Double) -> [Double] -> Bool +t_sum_shifted :: RealFloat a => ([a] -> a) -> [a] -> Bool t_sum_shifted f = t_sum_error f . zipWith (+) badvec trueSum :: (Fractional b, Real a) => [a] -> b trueSum xs = fromRational . Prelude.sum . map toRational $ xs -badvec :: [Double] +badvec :: RealFloat a => [a] badvec = cycle [1,1e16,-1e16] tests :: TestTree @@ -36,52 +41,76 @@ tests = testGroup "Summation" [ -- testProperty "t_sum_error" $ t_sum_error (sum id) -- testProperty "t_sum_shifted" $ t_sum_shifted (sum id) ] - , testGroup "Kahan" [ + , testGroup "Kahan" $ testShifted kahan -- tests that cannot pass: - -- testProprty "t_sum" $ t_sum (sum kahan) + -- testProperty "t_sum" $ t_sum (sum kahan) -- testProperty "t_sum_error" $ t_sum_error (sum kahan) -- kahan summation only beats normal summation with large values - testProperty "t_sum_shifted" $ t_sum_shifted (sum kahan) - ] - , testGroup "KBN" [ - testProperty "t_sum" $ t_sum (sum kbn) - , testProperty "t_sum_error" $ t_sum_error (sum kbn) - , testProperty "t_sum_shifted" $ t_sum_shifted (sum kbn) - ] - , testGroup "KB2" [ - testProperty "t_sum" $ t_sum (sum kb2) - , testProperty "t_sum_error" $ t_sum_error (sum kb2) - , testProperty "t_sum_shifted" $ t_sum_shifted (sum kb2) - ] + , testGroup "KBN" $ testSum kbn + , testGroup "KB2" $ testSum kb2 + ] + +type SummationTestTypes s = + ( Summation s Float + , Summation s Double + , Summation s CFloat + , Summation s CDouble + ) + +testShifted :: forall s. SummationTestTypes s + => (forall a. Summation s a => s a -> a) + -> [TestTree] +testShifted f = testOnTypes f [ ("t_sum_shifted", t_sum_shifted) ] + +testSum :: forall s. SummationTestTypes s + => (forall a. Summation s a => s a -> a) + -> [TestTree] +testSum f = testOnTypes f + [ ("t_sum", t_sum) + , ("t_sum_error", t_sum_error) + , ("t_sum_shifted", t_sum_shifted) + ] + +testOnTypes :: forall s. SummationTestTypes s + => (forall a. Summation s a => s a -> a) + -> (forall a. Summation s a => [ (String, ([a] -> a) -> [a] -> Bool) ]) + -> [TestTree] +testOnTypes f ts = + [ testGroup "Float" $ toTest (f :: s Float -> Float) <$> ts + , testGroup "Double" $ toTest (f :: s Double -> Double) <$> ts + , testGroup "CFloat" $ toTest (f :: s CFloat -> CFloat) <$> ts + , testGroup "CDouble" $ toTest (f :: s CDouble -> CDouble) <$> ts ] + where + toTest f' (testName, test) = testProperty testName $ test (sum f') -instance Arbitrary KahanSum where +instance Arbitrary a => Arbitrary (KahanSum a) where arbitrary = toKahan <$> arbitrary shrink = map toKahan . shrink . fromKahan -toKahan :: (Double, Double) -> KahanSum +toKahan :: (a, a) -> KahanSum a toKahan (a,b) = KahanSum a b -fromKahan :: KahanSum -> (Double, Double) +fromKahan :: KahanSum a -> (a, a) fromKahan (KahanSum a b) = (a,b) -instance Arbitrary KBNSum where +instance Arbitrary a => Arbitrary (KBNSum a) where arbitrary = toKBN <$> arbitrary shrink = map toKBN . shrink . fromKBN -toKBN :: (Double, Double) -> KBNSum +toKBN :: (a, a) -> KBNSum a toKBN (a,b) = KBNSum a b -fromKBN :: KBNSum -> (Double, Double) +fromKBN :: KBNSum a -> (a, a) fromKBN (KBNSum a b) = (a,b) -instance Arbitrary KB2Sum where +instance Arbitrary a => Arbitrary (KB2Sum a) where arbitrary = toKB2 <$> arbitrary shrink = map toKB2 . shrink . fromKB2 -toKB2 :: (Double, Double, Double) -> KB2Sum +toKB2 :: (a, a, a) -> KB2Sum a toKB2 (a,b,c) = KB2Sum a b c -fromKB2 :: KB2Sum -> (Double, Double, Double) +fromKB2 :: KB2Sum a -> (a, a, a) fromKB2 (KB2Sum a b c) = (a,b,c)