Skip to content

Commit d39bc58

Browse files
More sensible backPropagation implementation
1 parent e796646 commit d39bc58

File tree

3 files changed

+40
-28
lines changed

3 files changed

+40
-28
lines changed

src/Constrained/Base.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ module Constrained.Base (
3232
pattern (:<:),
3333
pattern (:>:),
3434
pattern Unary,
35-
Ctx,
35+
Ctx(..),
3636
toCtx,
3737
flipCtx,
3838
fromListCtx,
39+
ctxHasSpec,
3940

4041
-- * Useful function symbols and patterns for building custom rewrite rules
4142
fromGeneric_,

src/Constrained/Generation.hs

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ module Constrained.Generation (
4949
EqW (..),
5050
SumSpec (..),
5151
pattern SumSpec,
52+
53+
mapSpec,
54+
forwardPropagateSpec,
5255
) where
5356

5457
import Constrained.AbstractSyntax
@@ -907,10 +910,10 @@ backPropagation relevant (SolverPlan initplan) = SolverPlan (go [] (reverse init
907910
newStage _ _ = []
908911

909912
termVarEqCases :: HasSpec b => Specification a -> Var b -> Term b -> [SolverStage]
910-
termVarEqCases (MemberSpec vs) x' t
911-
| Set.singleton (Name x) == freeVarSet t =
912-
[SolverStage x' [] (MemberSpec (NE.nub (fmap (\v -> errorGE $ runTerm (Env.singleton x v) t) vs)))
913-
(Set.insert (Name x') relevant)]
913+
termVarEqCases specx x' t
914+
| Set.singleton (Name x) == freeVarSet t
915+
, Result ctx <- toCtx x t =
916+
[SolverStage x' [] (forwardPropagateSpec specx ctx) (Set.insert (Name x) relevant)]
914917
termVarEqCases specx x' t
915918
| Just Refl <- eqVar x x'
916919
, [Name y] <- Set.toList $ freeVarSet t
@@ -1329,3 +1332,34 @@ fromGESpec ge = case ge of
13291332
Result s -> s
13301333
GenError xs -> ErrorSpec (catMessageList xs)
13311334
FatalError es -> error $ catMessages es
1335+
1336+
-- TODO: move this somewhere sensible
1337+
1338+
-- | Functor like property for Specification, but instead of a Haskell function (a -> b),
1339+
-- it takes a function symbol (t '[a] b) from a to b.
1340+
-- Note, in this context, a function symbol is some constructor of a witnesstype.
1341+
-- Eg. ProdFstW, InjRightW, SingletonW, etc. NOT the lifted versions like fst_ singleton_,
1342+
-- which construct Terms. We had to wait until here to define this because it
1343+
-- depends on Semigroup property of Specification, and Asserting equality
1344+
mapSpec ::
1345+
forall t a b.
1346+
AppRequires t '[a] b =>
1347+
t '[a] b ->
1348+
Specification a ->
1349+
Specification b
1350+
mapSpec f (ExplainSpec es s) = explainSpec es (mapSpec f s)
1351+
mapSpec f TrueSpec = mapTypeSpec f (emptySpec @a)
1352+
mapSpec _ (ErrorSpec err) = ErrorSpec err
1353+
mapSpec f (MemberSpec as) = MemberSpec $ NE.nub $ fmap (semantics f) as
1354+
mapSpec f (SuspendedSpec x p) =
1355+
constrained $ \x' ->
1356+
Exists (\_ -> fatalError "mapSpec") (x :-> fold [Assert $ (x' ==. appTerm f (V x)), p])
1357+
mapSpec f (TypeSpec ts cant) = mapTypeSpec f ts <> notMemberSpec (map (semantics f) cant)
1358+
1359+
-- TODO generalizeme!
1360+
forwardPropagateSpec :: HasSpec a => Specification a -> Ctx a b -> Specification b
1361+
forwardPropagateSpec s CtxHOLE = s
1362+
forwardPropagateSpec s (CtxApp f (c :? Nil))
1363+
| Evidence <- ctxHasSpec c = mapSpec f (forwardPropagateSpec s c)
1364+
forwardPropagateSpec _ _ = TrueSpec
1365+

src/Constrained/TheKnot.hs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ module Constrained.TheKnot (
5151
rangeSize,
5252
hasSize,
5353
genInverse,
54-
mapSpec,
5554
between,
5655

5756
-- * Patterns
@@ -77,7 +76,6 @@ import Constrained.SumList
7776
-- Because it is mutually recursive with something else in here.
7877
import Constrained.Syntax
7978
import Control.Applicative
80-
import Control.Monad
8179
import Data.Foldable
8280
import Data.Kind
8381
import Data.List (nub)
@@ -99,27 +97,6 @@ ifElse b p q = whenTrue b p <> whenTrue (not_ b) q
9997

10098
-- =======================================================================================
10199

102-
-- | Functor like property for Specification, but instead of a Haskell function (a -> b),
103-
-- it takes a function symbol (t '[a] b) from a to b.
104-
-- Note, in this context, a function symbol is some constructor of a witnesstype.
105-
-- Eg. ProdFstW, InjRightW, SingletonW, etc. NOT the lifted versions like fst_ singleton_,
106-
-- which construct Terms. We had to wait until here to define this because it
107-
-- depends on Semigroup property of Specification, and Asserting equality
108-
mapSpec ::
109-
forall t a b.
110-
AppRequires t '[a] b =>
111-
t '[a] b ->
112-
Specification a ->
113-
Specification b
114-
mapSpec f (ExplainSpec es s) = explainSpec es (mapSpec f s)
115-
mapSpec f TrueSpec = mapTypeSpec f (emptySpec @a)
116-
mapSpec _ (ErrorSpec err) = ErrorSpec err
117-
mapSpec f (MemberSpec as) = MemberSpec $ NE.nub $ fmap (semantics f) as
118-
mapSpec f (SuspendedSpec x p) =
119-
constrained $ \x' ->
120-
Exists (\_ -> fatalError "mapSpec") (x :-> fold [Assert $ (x' ==. appTerm f (V x)), p])
121-
mapSpec f (TypeSpec ts cant) = mapTypeSpec f ts <> notMemberSpec (map (semantics f) cant)
122-
123100
-- ================================================================
124101
-- HasSpec for Products
125102
-- ================================================================

0 commit comments

Comments
 (0)