Skip to content

Commit b0afaa9

Browse files
WIP: Complete Num instance
1 parent 9cc97e0 commit b0afaa9

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

src/Constrained/NumOrd.hs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ cardinalNumSpec (NumSpecInterval Nothing Nothing) = cardinalTrueSpec @n
656656
-- Now the operations on Numbers
657657

658658
-- | Everything we need to make the number operations make sense on a given type
659-
class (Num a, HasSpec a, HasDivision a) => NumLike a where
659+
class (Num a, HasSpec a, HasDivision a, OrdLike a) => NumLike a where
660660
subtractSpec :: a -> TypeSpec a -> Specification a
661661
default subtractSpec ::
662662
( NumLike (SimpleRep a)
@@ -689,23 +689,31 @@ data IntW (as :: [Type]) b where
689689
AddW :: NumLike a => IntW '[a, a] a
690690
MultW :: NumLike a => IntW '[a, a] a
691691
NegateW :: NumLike a => IntW '[a] a
692+
AbsW :: NumLike a => IntW '[a] a
693+
SignumW :: NumLike a => IntW '[a] a
692694

693695
deriving instance Eq (IntW dom rng)
694696

695697
instance Show (IntW d r) where
696698
show AddW = "+"
697699
show NegateW = "negate_"
698700
show MultW = "*"
701+
show AbsW = "abs_"
702+
show SignumW = "signum_"
699703

700704
instance Semantics IntW where
701705
semantics AddW = (+)
702706
semantics NegateW = negate
703707
semantics MultW = (*)
708+
semantics AbsW = abs
709+
semantics SignumW = signum
704710

705711
instance Syntax IntW where
706712
isInfix AddW = True
707713
isInfix NegateW = False
708714
isInfix MultW = True
715+
isInfix AbsW = False
716+
isInfix SignumW = False
709717

710718
class HasDivision a where
711719
doDivide :: a -> a -> a
@@ -860,8 +868,8 @@ instance NumLike a => Num (Term a) where
860868
negate = negate_
861869
fromInteger = Lit . fromInteger
862870
(*) = (*.)
863-
abs = error "abs not implemented for Term Fn Int"
864-
signum = error "signum not implemented for Term Fn Int"
871+
abs = abs_
872+
signum = signum_
865873

866874
invertMult :: (HasSpec a, Num a, HasDivision a) => a -> a -> Maybe a
867875
invertMult a b =
@@ -879,6 +887,13 @@ instance Logic IntW where
879887
| otherwise = ErrorSpec $ NE.fromList [ "zero" ]
880888
propagateTypeSpec MultW (HOLE :<: i) ts cant = divideSpec i ts <> notMemberSpec (mapMaybe (flip invertMult i) cant)
881889
propagateTypeSpec MultW ctx ts cant = propagateTypeSpec MultW (flipCtx ctx) ts cant
890+
propagateTypeSpec AbsW (Unary HOLE) ts cant = error "TODO"
891+
propagateTypeSpec SignumW (Unary HOLE) ts cant =
892+
constrained $ \ x ->
893+
[ x `satisfies` notMemberSpec [0] | not $ ok 0 ] ++
894+
[ Assert $ 0 <=. x | not $ ok (-1) ] ++
895+
[ Assert $ x <=. x | not $ ok 1 ]
896+
where ok = flip conformsToSpec (TypeSpec ts cant)
882897

883898
propagateMemberSpec AddW (HOLE :<: i) es =
884899
memberSpec
@@ -896,6 +911,15 @@ instance Logic IntW where
896911
| otherwise = ErrorSpec $ NE.fromList [ "zero" ]
897912
propagateMemberSpec MultW (HOLE :<: i) es = memberSpec (mapMaybe (flip invertMult i) (NE.toList es)) (NE.fromList ["propagateSpec"])
898913
propagateMemberSpec MultW ctx es = propagateMemberSpec MultW (flipCtx ctx) es
914+
propagateMemberSpec AbsW (Unary HOLE) es
915+
| all ((== -1) . signum) es = ErrorSpec $ NE.fromList [ "abs for all negative member spec", show es ]
916+
| otherwise = MemberSpec $ NE.nub . NE.fromList $ concat $ [ [e, negate e] | e <- NE.toList es, signum e /= -1 ]
917+
propagateMemberSpec SignumW (Unary HOLE) es
918+
| all ((`notElem` [-1, 0, 1]) . signum) es = ErrorSpec $ NE.fromList [ "signum for invalid member spec", show es ]
919+
| otherwise = constrained $ \ x ->
920+
[ x `satisfies` notMemberSpec [0] | 0 `notElem` es ] ++
921+
[ Assert $ 0 <=. x | -1 `notElem` es ] ++
922+
[ Assert $ x <=. x | 1 `notElem` es ]
899923

900924
rewriteRules AddW (x :> y :> Nil) _ | x == y = Just $ 2 * x
901925
rewriteRules _ _ _ = Nothing
@@ -916,6 +940,14 @@ infixl 7 *.
916940
negate_ :: NumLike a => Term a -> Term a
917941
negate_ = appTerm NegateW
918942

943+
-- | `Term`-level `abs`
944+
abs_ :: NumLike a => Term a -> Term a
945+
abs_ = appTerm AbsW
946+
947+
-- | `Term`-level `signum`
948+
signum_ :: NumLike a => Term a -> Term a
949+
signum_ = appTerm SignumW
950+
919951
infix 4 -.
920952

921953
-- | `Term`-level `(-)`

0 commit comments

Comments
 (0)