Skip to content

Commit ed45645

Browse files
More sensible backPropagation implementation
1 parent 8a90cf2 commit ed45645

File tree

5 files changed

+72
-54
lines changed

5 files changed

+72
-54
lines changed

examples/Constrained/Examples/Basic.hs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,23 @@ manyInconsistentTrans = constrained' $ \ [var| a |] [var| b |] c d e [var| f |]
343343
, assert $ f >. 10
344344
, assert $ f <. b
345345
]
346+
347+
complicatedEither :: Specification (Either Int Int, (Either Int Int, Int, Int))
348+
complicatedEither = constrained' $ \ [var| i |] [var| t |] ->
349+
[ caseOn i
350+
(branch $ \ a -> a `elem_` lit [1..10])
351+
(branch $ \ b -> b `elem_` lit [1..10])
352+
, match t $ \ [var| k |] _ _ ->
353+
[ k ==. i
354+
, not_ $ k `elem_` lit [ Left j | j <- [1..9] ]
355+
]
356+
]
357+
358+
pairCant :: Specification (Int, (Int, Int))
359+
pairCant = constrained' $ \ [var| i |] [var| p |] ->
360+
[ assert $ i `elem_` lit [1..10]
361+
, match p $ \ [var| k |] _ ->
362+
[ k ==. i
363+
, not_ $ k `elem_` lit [1..9]
364+
]
365+
]

src/Constrained/Base.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ module Constrained.Base (
3232
pattern (:<:),
3333
pattern (:>:),
3434
pattern Unary,
35-
Ctx,
35+
Ctx(..),
3636
toCtx,
3737
flipCtx,
3838
fromListCtx,
39+
ctxHasSpec,
3940

4041
-- * Useful function symbols and patterns for building custom rewrite rules
4142
fromGeneric_,

src/Constrained/Generation.hs

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ module Constrained.Generation (
4949
EqW (..),
5050
SumSpec (..),
5151
pattern SumSpec,
52+
53+
mapSpec,
54+
forwardPropagateSpec,
5255
) where
5356

5457
import Constrained.AbstractSyntax
@@ -833,17 +836,7 @@ mergeSolverStage (SolverStage x ps spec relevant) plan =
833836
normalizeSolverStage $ SolverStage
834837
y
835838
(ps ++ ps')
836-
( addToErrorSpec
837-
( NE.fromList
838-
( [ "Solving var " ++ show x ++ " fails."
839-
, "Merging the Specs"
840-
, " 1. " ++ show spec
841-
, " 2. " ++ show spec'
842-
]
843-
)
844-
)
845-
(spec <> spec')
846-
)
839+
(spec <> spec')
847840
(relevant <> relevant')
848841
Nothing -> stage
849842
| stage@(SolverStage y ps' spec' relevant') <- plan
@@ -897,27 +890,21 @@ backPropagation relevant (SolverPlan initplan) = SolverPlan (go [] (reverse init
897890
go acc [] = acc
898891
go acc (s@(SolverStage (x :: Var a) ps spec _) : plan) = go (s : acc) plan'
899892
where
900-
newStages = concatMap (newStage spec) ps
893+
newStages = concatMap newStage ps
901894
plan' = foldr mergeSolverStage plan newStages
895+
902896
-- Note use of the Term Pattern Equal
903-
newStage specl (Assert (Equal (V x') t)) =
904-
termVarEqCases specl x' t
905-
newStage specr (Assert (Equal t (V x'))) =
906-
termVarEqCases specr x' t
907-
newStage _ _ = []
908-
909-
termVarEqCases :: HasSpec b => Specification a -> Var b -> Term b -> [SolverStage]
910-
termVarEqCases (MemberSpec vs) x' t
911-
| Set.singleton (Name x) == freeVarSet t =
912-
[SolverStage x' [] (MemberSpec (NE.nub (fmap (\v -> errorGE $ runTerm (Env.singleton x v) t) vs)))
913-
(Set.insert (Name x') relevant)]
914-
termVarEqCases specx x' t
915-
| Just Refl <- eqVar x x'
916-
, [Name y] <- Set.toList $ freeVarSet t
917-
, Result ctx <- toCtx y t =
918-
[SolverStage y [] (propagateSpec specx ctx)
919-
(Set.insert (Name x') relevant)]
920-
termVarEqCases _ _ _ = []
897+
newStage (Assert (Equal tl tr))
898+
| [Name xl] <- Set.toList $ freeVarSet tl
899+
, [Name xr] <- Set.toList $ freeVarSet tr
900+
, Name x `elem` [Name xl, Name xr]
901+
, Result ctxL <- toCtx xl tl
902+
, Result ctxR <- toCtx xr tr
903+
= case (eqVar x xl, eqVar x xr) of
904+
(Just Refl, _) -> [SolverStage xr [] (propagateSpec (forwardPropagateSpec spec ctxL) ctxR) (Set.insert (Name x) relevant)]
905+
(_, Just Refl) -> [SolverStage xl [] (propagateSpec (forwardPropagateSpec spec ctxR) ctxL) (Set.insert (Name x) relevant)]
906+
_ -> error "The impossible happened"
907+
newStage _ = []
921908

922909
-- | Function symbols for `(==.)`
923910
data EqW :: [Type] -> Type -> Type where
@@ -1329,3 +1316,34 @@ fromGESpec ge = case ge of
13291316
Result s -> s
13301317
GenError xs -> ErrorSpec (catMessageList xs)
13311318
FatalError es -> error $ catMessages es
1319+
1320+
-- TODO: move this somewhere sensible
1321+
1322+
-- | Functor like property for Specification, but instead of a Haskell function (a -> b),
1323+
-- it takes a function symbol (t '[a] b) from a to b.
1324+
-- Note, in this context, a function symbol is some constructor of a witnesstype.
1325+
-- Eg. ProdFstW, InjRightW, SingletonW, etc. NOT the lifted versions like fst_ singleton_,
1326+
-- which construct Terms. We had to wait until here to define this because it
1327+
-- depends on Semigroup property of Specification, and Asserting equality
1328+
mapSpec ::
1329+
forall t a b.
1330+
AppRequires t '[a] b =>
1331+
t '[a] b ->
1332+
Specification a ->
1333+
Specification b
1334+
mapSpec f (ExplainSpec es s) = explainSpec es (mapSpec f s)
1335+
mapSpec f TrueSpec = mapTypeSpec f (emptySpec @a)
1336+
mapSpec _ (ErrorSpec err) = ErrorSpec err
1337+
mapSpec f (MemberSpec as) = MemberSpec $ NE.nub $ fmap (semantics f) as
1338+
mapSpec f (SuspendedSpec x p) =
1339+
constrained $ \x' ->
1340+
Exists (\_ -> fatalError "mapSpec") (x :-> fold [Assert $ (x' ==. appTerm f (V x)), p])
1341+
mapSpec f (TypeSpec ts cant) = mapTypeSpec f ts <> notMemberSpec (map (semantics f) cant)
1342+
1343+
-- TODO generalizeme!
1344+
forwardPropagateSpec :: HasSpec a => Specification a -> Ctx a b -> Specification b
1345+
forwardPropagateSpec s CtxHOLE = s
1346+
forwardPropagateSpec s (CtxApp f (c :? Nil))
1347+
| Evidence <- ctxHasSpec c = mapSpec f (forwardPropagateSpec s c)
1348+
forwardPropagateSpec _ _ = TrueSpec
1349+

src/Constrained/TheKnot.hs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ module Constrained.TheKnot (
5151
rangeSize,
5252
hasSize,
5353
genInverse,
54-
mapSpec,
5554
between,
5655

5756
-- * Patterns
@@ -77,7 +76,6 @@ import Constrained.SumList
7776
-- Because it is mutually recursive with something else in here.
7877
import Constrained.Syntax
7978
import Control.Applicative
80-
import Control.Monad
8179
import Data.Foldable
8280
import Data.Kind
8381
import Data.List (nub)
@@ -99,27 +97,6 @@ ifElse b p q = whenTrue b p <> whenTrue (not_ b) q
9997

10098
-- =======================================================================================
10199

102-
-- | Functor like property for Specification, but instead of a Haskell function (a -> b),
103-
-- it takes a function symbol (t '[a] b) from a to b.
104-
-- Note, in this context, a function symbol is some constructor of a witnesstype.
105-
-- Eg. ProdFstW, InjRightW, SingletonW, etc. NOT the lifted versions like fst_ singleton_,
106-
-- which construct Terms. We had to wait until here to define this because it
107-
-- depends on Semigroup property of Specification, and Asserting equality
108-
mapSpec ::
109-
forall t a b.
110-
AppRequires t '[a] b =>
111-
t '[a] b ->
112-
Specification a ->
113-
Specification b
114-
mapSpec f (ExplainSpec es s) = explainSpec es (mapSpec f s)
115-
mapSpec f TrueSpec = mapTypeSpec f (emptySpec @a)
116-
mapSpec _ (ErrorSpec err) = ErrorSpec err
117-
mapSpec f (MemberSpec as) = MemberSpec $ NE.nub $ fmap (semantics f) as
118-
mapSpec f (SuspendedSpec x p) =
119-
constrained $ \x' ->
120-
Exists (\_ -> fatalError "mapSpec") (x :-> fold [Assert $ (x' ==. appTerm f (V x)), p])
121-
mapSpec f (TypeSpec ts cant) = mapTypeSpec f ts <> notMemberSpec (map (semantics f) cant)
122-
123100
-- ================================================================
124101
-- HasSpec for Products
125102
-- ================================================================

test/Constrained/Tests.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ testAll = hspec $ tests False
5959
tests :: Bool -> Spec
6060
tests nightly =
6161
describe "constrained" . modifyMaxSuccess (\ms -> if nightly then ms * 10 else ms) $ do
62+
testSpec "complicatedEither" complicatedEither
63+
testSpec "pairCatn" pairCant
6264
-- TODO: double-shrinking
6365
testSpecNoShrink "reifiesMultiple" reifiesMultiple
6466
testSpec "assertReal" assertReal

0 commit comments

Comments
 (0)