@@ -28,6 +28,7 @@ module Constrained.NumOrd (
2828 (>=.) ,
2929 (<=.) ,
3030 (+.) ,
31+ (*.) ,
3132 negate_ ,
3233 cardinality ,
3334 caseBoolSpec ,
@@ -202,6 +203,8 @@ deriving via Unbounded (Ratio Integer) instance MaybeBounded (Ratio Integer)
202203
203204deriving via Unbounded Float instance MaybeBounded Float
204205
206+ deriving via Unbounded Double instance MaybeBounded Double
207+
205208instance MaybeBounded Natural where
206209 lowerBound = Just 0
207210 upperBound = Nothing
@@ -651,7 +654,7 @@ cardinalNumSpec (NumSpecInterval Nothing Nothing) = cardinalTrueSpec @n
651654-- Now the operations on Numbers
652655
653656-- | Everything we need to make the number operations make sense on a given type
654- class (Num a , HasSpec a ) => NumLike a where
657+ class (Num a , HasSpec a , HasDivision a ) => NumLike a where
655658 subtractSpec :: a -> TypeSpec a -> Specification a
656659 default subtractSpec ::
657660 ( NumLike (SimpleRep a)
@@ -682,24 +685,132 @@ class (Num a, HasSpec a) => NumLike a where
682685-- | Operations on numbers
683686data IntW (as :: [Type ]) b where
684687 AddW :: NumLike a => IntW '[a , a ] a
688+ MultW :: NumLike a => IntW '[a , a ] a
685689 NegateW :: NumLike a => IntW '[a ] a
686690
687691deriving instance Eq (IntW dom rng )
688692
689693instance Show (IntW d r ) where
690694 show AddW = " +"
691695 show NegateW = " negate_"
696+ show MultW = " *"
692697
693698instance Semantics IntW where
694699 semantics AddW = (+)
695700 semantics NegateW = negate
701+ semantics MultW = (*)
696702
697703instance Syntax IntW where
698704 isInfix AddW = True
699705 isInfix NegateW = False
706+ isInfix MultW = True
707+
708+ class HasDivision a where
709+ doDivide :: a -> a -> a
710+ default doDivide ::
711+ ( HasDivision (SimpleRep a)
712+ , GenericRequires a
713+ ) =>
714+ a ->
715+ a ->
716+ a
717+ doDivide a b = fromSimpleRep $ doDivide (toSimpleRep a) (toSimpleRep b)
718+
719+ divideSpec :: a -> TypeSpec a -> Specification a
720+ default divideSpec ::
721+ ( HasDivision (SimpleRep a)
722+ , GenericRequires a
723+ ) =>
724+ a ->
725+ TypeSpec a ->
726+ Specification a
727+ divideSpec a ts = fromSimpleRepSpec $ divideSpec (toSimpleRep a) ts
728+
729+ instance {-# OVERLAPPABLE #-} (HasSpec a , MaybeBounded a , Integral a , TypeSpec a ~ NumSpec a ) => HasDivision a where
730+ doDivide = div
731+
732+ divideSpec 0 _ = TrueSpec
733+ divideSpec a (NumSpecInterval (unionWithMaybe max lowerBound -> ml) (unionWithMaybe min upperBound -> mu)) = typeSpec ts
734+ where
735+ ts | a > 0 = NumSpecInterval ml' mu'
736+ | otherwise = NumSpecInterval mu' ml'
737+ ml' = adjustLowerBound <$> ml
738+ mu' = adjustUpperBound <$> mu
739+
740+ -- NOTE: negate has different overflow semantics than div, so that's why we use negate below...
741+
742+ adjustLowerBound l
743+ | a == 1 = l
744+ | a == - 1 = negate l
745+ | otherwise =
746+ let r = l `div` a in
747+ if toInteger r * toInteger a < toInteger l
748+ then r + signum a
749+ else r
750+
751+ adjustUpperBound u
752+ | a == 1 = u
753+ | a == - 1 = negate u
754+ | otherwise =
755+ let r = u `div` a in
756+ if toInteger r * toInteger a > toInteger u
757+ then r - signum a
758+ else r
759+
760+ instance HasDivision Float where
761+ doDivide = (/)
762+
763+ divideSpec 0 _ = TrueSpec
764+ divideSpec a (NumSpecInterval ml mu) = typeSpec ts
765+ where
766+ ts | a > 0 = NumSpecInterval ml' mu'
767+ | otherwise = NumSpecInterval mu' ml'
768+ ml' = adjustLowerBound <$> ml
769+ mu' = adjustUpperBound <$> mu
770+ adjustLowerBound l =
771+ let r = l / a
772+ l' = r * a
773+ in
774+ if l' < l
775+ then r + (l - l') * 2 / a
776+ else r
777+
778+ adjustUpperBound u =
779+ let r = u / a
780+ u' = r * a
781+ in
782+ if u < u'
783+ then r - (u' - u) * 2 / a
784+ else r
785+
786+ instance HasDivision Double where
787+ doDivide = (/)
788+
789+ divideSpec 0 _ = TrueSpec
790+ divideSpec a (NumSpecInterval ml mu) = typeSpec ts
791+ where
792+ ts | a > 0 = NumSpecInterval ml' mu'
793+ | otherwise = NumSpecInterval mu' ml'
794+ ml' = adjustLowerBound <$> ml
795+ mu' = adjustUpperBound <$> mu
796+ adjustLowerBound l =
797+ let r = l / a
798+ l' = r * a
799+ in
800+ if l' < l
801+ then r + (l - l') * 2 / a
802+ else r
803+
804+ adjustUpperBound u =
805+ let r = u / a
806+ u' = r * a
807+ in
808+ if u < u'
809+ then r - (u' - u) * 2 / a
810+ else r
700811
701812-- | A type that we can reason numerically about in constraints
702- type Numeric a = (HasSpec a , Ord a , Num a , TypeSpec a ~ NumSpec a , MaybeBounded a )
813+ type Numeric a = (HasSpec a , Ord a , Num a , TypeSpec a ~ NumSpec a , MaybeBounded a , HasDivision a )
703814
704815instance {-# OVERLAPPABLE #-} Numeric a => NumLike a where
705816 subtractSpec a ts@ (NumSpecInterval ml mu)
@@ -728,6 +839,7 @@ instance {-# OVERLAPPABLE #-} Numeric a => NumLike a where
728839 | Just r <- safeSubtract a1 x = r
729840 | a1 < 0 = fromJust upperBound
730841 | otherwise = fromJust lowerBound
842+
731843 negateSpec (NumSpecInterval ml mu) = typeSpec $ NumSpecInterval (negate <$> mu) (negate <$> ml)
732844
733845 safeSubtract a x
@@ -742,20 +854,29 @@ instance {-# OVERLAPPABLE #-} Numeric a => NumLike a where
742854 | otherwise = Just $ x - a
743855
744856instance NumLike a => Num (Term a ) where
745- (+) = addFn
746- negate = negateFn
857+ (+) = (+.)
858+ negate = negate_
747859 fromInteger = Lit . fromInteger
748- (*) = error " (*) not implemented for Term Fn Int "
860+ (*) = (*.)
749861 abs = error " abs not implemented for Term Fn Int"
750862 signum = error " signum not implemented for Term Fn Int"
751863
864+ invertMult :: (HasSpec a , Num a , HasDivision a ) => a -> a -> Maybe a
865+ invertMult a b =
866+ let r = a `doDivide` b in if r * b == a then Just r else Nothing
867+
752868-- | Just a note that these instances won't work until we are in a context where
753869-- there is a HasSpec instance of 'a', which (NumLike a) demands.
754870-- This happens in Constrained.Experiment.TheKnot
755871instance Logic IntW where
756872 propagateTypeSpec AddW (HOLE :<: i) ts cant = subtractSpec i ts <> notMemberSpec (mapMaybe (safeSubtract i) cant)
757873 propagateTypeSpec AddW ctx ts cant = propagateTypeSpec AddW (flipCtx ctx) ts cant
758874 propagateTypeSpec NegateW (Unary HOLE ) ts cant = negateSpec ts <> notMemberSpec (map negate cant)
875+ propagateTypeSpec MultW (HOLE :<: 0 ) ts cant
876+ | 0 `conformsToSpec` TypeSpec ts cant = TrueSpec
877+ | otherwise = ErrorSpec $ NE. fromList [ " zero" ]
878+ propagateTypeSpec MultW (HOLE :<: i) ts cant = divideSpec i ts <> notMemberSpec (mapMaybe (flip invertMult i) cant)
879+ propagateTypeSpec MultW ctx ts cant = propagateTypeSpec MultW (flipCtx ctx) ts cant
759880
760881 propagateMemberSpec AddW (HOLE :<: i) es =
761882 memberSpec
@@ -768,28 +889,36 @@ instance Logic IntW where
768889 )
769890 propagateMemberSpec AddW ctx es = propagateMemberSpec AddW (flipCtx ctx) es
770891 propagateMemberSpec NegateW (Unary HOLE ) es = MemberSpec $ NE. nub $ fmap negate es
892+ propagateMemberSpec MultW (HOLE :<: 0 ) es
893+ | 0 `elem` es = TrueSpec
894+ | otherwise = ErrorSpec $ NE. fromList [ " zero" ]
895+ propagateMemberSpec MultW (HOLE :<: i) es = memberSpec (mapMaybe (flip invertMult i) (NE. toList es)) (NE. fromList [" propagateSpec" ])
896+ propagateMemberSpec MultW ctx es = propagateMemberSpec MultW (flipCtx ctx) es
771897
772- addFn :: forall a . NumLike a => Term a -> Term a -> Term a
773- addFn = appTerm AddW
774-
775- negateFn :: forall a . NumLike a => Term a -> Term a
776- negateFn = appTerm NegateW
898+ rewriteRules AddW (x :> y :> Nil ) _ | x == y = Just $ 2 * x
899+ rewriteRules _ _ _ = Nothing
777900
778901infix 4 +.
779902
780903-- | `Term`-level `(+)`
781904(+.) :: NumLike a => Term a -> Term a -> Term a
782- (+.) = addFn
905+ (+.) = appTerm AddW
906+
907+ infixl 7 *.
908+
909+ -- | `Term`-level `(+)`
910+ (*.) :: NumLike a => Term a -> Term a -> Term a
911+ (*.) = appTerm MultW
783912
784913-- | `Term`-level `negate`
785914negate_ :: NumLike a => Term a -> Term a
786- negate_ = negateFn
915+ negate_ = appTerm NegateW
787916
788917infix 4 -.
789918
790919-- | `Term`-level `(-)`
791920(-.) :: Numeric n => Term n -> Term n -> Term n
792- (-.) x y = addFn x (negateFn y)
921+ (-.) x y = x +. negate_ y
793922
794923infixr 4 <=.
795924
@@ -1029,3 +1158,14 @@ instance HasSpec Float where
10291158 toPreds = toPredsNumSpec
10301159 cardinalTypeSpec _ = TrueSpec
10311160 guardTypeSpec = guardNumSpec
1161+
1162+ instance HasSpec Double where
1163+ type TypeSpec Double = NumSpec Double
1164+ emptySpec = emptyNumSpec
1165+ combineSpec = combineNumSpec
1166+ genFromTypeSpec = genFromNumSpec
1167+ shrinkWithTypeSpec = shrinkWithNumSpec
1168+ conformsTo = conformsToNumSpec
1169+ toPreds = toPredsNumSpec
1170+ cardinalTypeSpec _ = TrueSpec
1171+ guardTypeSpec = guardNumSpec
0 commit comments