diff --git a/README.md b/README.md index a64f457..c8a87a0 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,8 @@ See the [youtube presentation](https://youtu.be/cESdgot_ZxY) for more details ab [This lecture](https://youtu.be/v=KcfD3Iv--UM) is a pedagogical explanation of the Curry-Howard correspondence in the context of functional programming. +See also a [recent presentation at the Haskell User's Group meetup](https://youtu.be/OFBwrMo1ESk). + # Unit tests `sbt test` @@ -156,6 +158,7 @@ Build the tutorial (thanks to the [tut plugin](https://github.com/tpolecat/tut)) # Revision history +- 0.3.8 Support Scala 2.13 (keep supporting Scala 2.11 and 2.12) - 0.3.7 Implement the `typeExpr` macro instead of the old test-only API. Detect and use `val`s from the immediately enclosing class. Minor performance improvements and bug fixes (alpha-conversion for STLC terms). Tests for automatic discovery of some monads. - 0.3.6 STLC terms are now emitted for `implement` as well; the JVM bytecode limit is obviated; fixed bug with recognizing `Function10`. - 0.3.5 Added `:@@` and `@@:` operations to the STLC interpreter. Fixed a bug whereby `Tuple2(x._1, x._2)` was not simplified to `x`. Fixed other bugs in alpha-conversion of type parameters. diff --git a/build.sbt b/build.sbt index ecd1f19..5dd9f8a 100644 --- a/build.sbt +++ b/build.sbt @@ -124,7 +124,7 @@ lazy val curryhoward: Project = (project in file(".")) .settings(common) .settings( organization := "io.chymyst", - version := "0.3.8", + version := "0.3.9", licenses := Seq("Apache License, Version 2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0.txt")), homepage := Some(url("https://github.com/Chymyst/curryhoward")), @@ -157,14 +157,14 @@ lazy val curryhoward: Project = (project in file(".")) ///////////////////////////////////////////////////////////////////////////////////////////////////// // Publishing to Sonatype Maven repository publishMavenStyle := true -publishTo := sonatypePublishToBundle.value -/*{ +publishTo := //sonatypePublishToBundle.value +{ val nexus = "https://oss.sonatype.org/" if (isSnapshot.value) Some("snapshots" at nexus + "content/repositories/snapshots") else Some("releases" at nexus + "service/local/staging/deploy/maven2") -}*/ +} // publishArtifact in Test := false // diff --git a/docs/Tutorial.md b/docs/Tutorial.md index 34706e2..4d28e93 100644 --- a/docs/Tutorial.md +++ b/docs/Tutorial.md @@ -888,8 +888,8 @@ res26: Boolean = true | `a.substTypeVar(b, c)` | `TermExpr ⇒ (TermExpr, TermExpr) ⇒ TermExpr` | replace a type variable in `a`; the type variable is specified as the type of `b`, and the replacement type is specified as the type of `c` | | `a.substTypeVars(s)` | `TermExpr ⇒ Map[TP, TypeExpr] ⇒ TermExpr` | replace all type variables in `a` according to the given substitution map `s` -- all type variables are substituted at once | | `u()` | `TermExpr ⇒ () ⇒ TermExpr` and `TypeExpr ⇒ () ⇒ TermExpr` | create a "named Unit" term of type `u.t` -- the type of `u` must be a named unit type, e.g. `None.type` or a case class with no constructors | -| `c(x...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c.t` -- the type of `c` must be a conjunction whose parts match the types of the arguments `x...` | -| `d(x)` | `TermExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d.t` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type | +| `c(x...)` | `TypeExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c` -- the type `c` must be a conjunction whose parts match the types of the arguments `x...` | +| `d(x)` | `TypeExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type | | `c(i)` | `TermExpr ⇒ Int ⇒ TermExpr` | project a conjunction term onto part with given zero-based index -- the type of `c` must be a conjunction with sufficiently many parts | | `c("id")` | `TermExpr ⇒ String ⇒ TermExpr` | project a conjunction term onto part with given accessor name -- the type of `c` must be a named conjunction that supports this accessor | | `d.cases(x =>: ..., y =>: ..., ...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` | create a term that pattern-matches on the given disjunction term -- the type of `d` must be a disjunction whose arguments match the arguments `x`, `y`, ... of the given case clauses | diff --git a/src/main/scala/io/chymyst/ch/TermExpr.scala b/src/main/scala/io/chymyst/ch/TermExpr.scala index 6829ca4..d9aba93 100644 --- a/src/main/scala/io/chymyst/ch/TermExpr.scala +++ b/src/main/scala/io/chymyst/ch/TermExpr.scala @@ -145,7 +145,7 @@ object TermExpr { thisVar.name === otherVar.name && (thisVar.t === otherVar.t || TypeExpr.isDisjunctionPart(thisVar.t, otherVar.t)) } - /** Replace all non-free occurrences of variable `replaceVar` by expression `byExpr` in `origExpr`. + /** Replace all free occurrences of variable `replaceVar` by expression `byExpr` in `origExpr`. * * @param replaceVar A variable that may occur freely in `origExpr`. * @param byExpr A new expression to replace all free occurrences of that variable. @@ -156,7 +156,7 @@ object TermExpr { // Check that all instances of replaceVar in origExpr have the correct type. val badVars = origExpr.freeVars.filter(_.name === replaceVar.name).filterNot(varMatchesType(_, replaceVar)) if (badVars.nonEmpty) { - throw new Exception(s"In subst($replaceVar, $byExpr, $origExpr), found variable(s) ${badVars.map(v ⇒ s"(${v.name}:${v.t.prettyPrint})").mkString(", ")} with incorrect type(s), expected variable type ${replaceVar.t.prettyPrint}") + throw new Exception(s"In subst($replaceVar:${replaceVar.t.prettyPrint}, $byExpr, $origExpr), found variable(s) ${badVars.map(v ⇒ s"(${v.name}:${v.t.prettyPrint})").mkString(", ")} with incorrect type(s), expected variable type ${replaceVar.t.prettyPrint}") } // Do we need an alpha-conversion? Better be safe than sorry. val (convertedReplaceVar, convertedOrigExpr) = if (byExpr.usedVarNames contains replaceVar.name) { @@ -167,7 +167,7 @@ object TermExpr { substMap(convertedOrigExpr) { case c@CurriedE(heads, _) if heads.exists(_.name === convertedReplaceVar.name) ⇒ c - // If a variable from `heads` collides with `convertedReplaceVar`, we do not replace anything in the body. + // If a variable from `heads` collides with `convertedReplaceVar`, we do not replace anything in the body because the variable occurs as non-free. case v@VarE(_, _) if varMatchesType(v, convertedReplaceVar) ⇒ byExpr } @@ -342,6 +342,43 @@ object TermExpr { } private[ch] def roundFactor(x: Double): Int = math.round(x * 10000).toInt + + /** Generate all necessary fresh variables for equality checking of functions that consume disjunction types. + * + * @param typeExpr The type of the argument expression. + * @return A sequence of [[TermExpr]] values containing the necessary fresh variables. + */ + def subtypeVars(typeExpr: TypeExpr): Seq[TermExpr] = typeExpr match { + case dt@DisjunctT(_, _, terms) ⇒ terms.zipWithIndex.flatMap { case (t, i) ⇒ subtypeVars(t).map(v ⇒ DisjunctE(i, terms.length, v, dt)) } + case nct@NamedConjunctT(_, _, _, wrapped) ⇒ + TheoremProver.explode(wrapped.map(subtypeVars)).map(NamedConjunctE(_, nct)) + case _ ⇒ Seq(VarE(freshIdents(), typeExpr)) + } + + /** Extensional equality check. If the term expressions are functions, fresh variables are substituted as arguments and the results are compared with `equiv`. + * + * @param termExpr1 The first term. + * @param termExpr2 The second term. + * @return `true` if the terms are extensionally equal. + */ + def extEqual(termExpr1: TermExpr, termExpr2: TermExpr): Boolean = { + val t1 = termExpr1.simplify + val t2 = termExpr2.simplify + (t1.t === t2.t) && ( + (t1 equiv t2) || { + println(s"DEBUG: checking extensional equality of ${t1.prettyPrint} and ${t2.prettyPrint}") + (t1, t2) match { + case (CurriedE(h1 :: _, _), CurriedE(_ :: _, _)) ⇒ + subtypeVars(h1.t).forall { term ⇒ + val result = extEqual(t1(term), t2(term)) + if (!result) println(s"DEBUG: found inequality after substituting term ${term.prettyPrint}") + result + } + case _ ⇒ false + } + } + ) + } } sealed trait TermExpr { @@ -507,12 +544,33 @@ sealed trait TermExpr { "(" + leftZeros.mkString(" + ") + leftZerosString + term.prettyPrintWithParentheses(0) + rightZerosString + rightZeros.mkString(" + ") + ")" } + lazy val printScala: String = printScalaWithTypes() + + private[ch] def printScalaWithTypes(withTypes: Boolean = false): String = this match { + case VarE(name, _) ⇒ name + (if (withTypes) ": " + t.prettyPrint else "") + case AppE(head, arg) ⇒ + val h = head.printScalaWithTypes(true) + val b = arg.printScalaWithTypes() + s"$h($b)" + case CurriedE(heads, body) ⇒ + s"${heads.map(_.printScalaWithTypes(true)).mkString(" ⇒ ")} ⇒ ${body.printScalaWithTypes()}" + case UnitE(_) ⇒ "()" + case ConjunctE(terms) ⇒ "(" + terms.map(_.printScalaWithTypes()).mkString(", ") + ")" + case NamedConjunctE(terms, tExpr) ⇒ if (tExpr.wrapped.isEmpty) tExpr.constructor.toString + else s"${tExpr.constructor.toString}(${terms.map(_.printScalaWithTypes()).mkString(", ")})" + case ProjectE(index, term) ⇒ term.printScalaWithTypes() + "." + term.accessor(index) + case MatchE(term, cases) ⇒ + term.printScalaWithTypes() + " match { case " + cases.map(_.printScalaWithTypes(true)).mkString("; case ") + " }" + case DisjunctE(index, total, term, _) ⇒ + term.printScalaWithTypes() + } + private def prettyVars: Iterator[String] = for { number ← Iterator.single("") ++ Iterator.from(1).map(_.toString) letter ← ('a' to 'z').toIterator } yield s"$letter$number" - private lazy val renameBoundVars: TermExpr = TermExpr.substMap(this) { + private[ch] lazy val renameBoundVars: TermExpr = TermExpr.substMap(this) { case CurriedE(heads, body) ⇒ val oldAndNewVars = heads.map { v ⇒ (v, VarE(TermExpr.freshIdents(), v.t)) } val renamedBody = oldAndNewVars.foldLeft(body.renameBoundVars) { case (prev, (oldVar, newVar)) ⇒ @@ -803,36 +861,57 @@ final case class MatchE(term: TermExpr, cases: List[TermExpr]) extends TermExpr } private[ch] override def simplifyOnceInternal(withEta: Boolean): TermExpr = { - lazy val casesSimplified = cases.map(_.simplifyOnce(withEta)) + val ncases = cases.length term.simplifyOnce(withEta) match { // Match a fixed part of the disjunction; can be simplified to just one clause. // Example: Left(a) match { case Left(x) => f(x); case Right(y) => ... } can be simplified to just f(a). case DisjunctE(index, total, termInjected, _) ⇒ - if (total === cases.length) { + if (total === ncases) { AppE(cases(index).simplifyOnce(withEta), termInjected).simplifyOnce(withEta) - } else throw new Exception(s"Internal error: MatchE with ${cases.length} cases applied to DisjunctE with $total parts, but must be of equal size") + } else throw new Exception(s"Internal error: MatchE with $ncases cases applied to DisjunctE with $total parts, but must be of equal size") // Match of an inner match, can be simplified to a single match. - // Example: (Left(a) match { case Left(x) ⇒ ...; case Right(y) ⇒ ... }) match { case ... ⇒ ... } - // can be simplified to Left(a) match { case Left(x) ⇒ ... match { case ... ⇒ ...}; case Right(y) ⇒ ... match { case ... ⇒ ... } } + // Example: (q match { case Left(x) ⇒ ...; case Right(y) ⇒ ... }) match { case ... ⇒ ... } + // can be simplified to q match { case Left(x) ⇒ ... match { case ... ⇒ ...}; case Right(y) ⇒ ... match { case ... ⇒ ... } } case MatchE(innerTerm, innerCases) ⇒ MatchE(innerTerm, innerCases map { case CurriedE(List(head), body) ⇒ CurriedE(List(head), MatchE(body, cases)) }) + .simplifyOnce(withEta) // Detect the identity patterns: // MatchE(_, List(a ⇒ DisjunctE(0, total, a, _), a ⇒ DisjunctE(1, total, a, _), ...)) // MatchE(_, a: T1 ⇒ DisjunctE(i, total, NamedConjunctE(List(ProjectE(0, a), Project(1, a), ...), T1), ...), _) case termSimplified ⇒ - if (cases.nonEmpty && { + + // Replace redundant matches on the same term, can be simplified by eliminating one match subexpresssion. + // Example: q match { case x ⇒ q match { case y ⇒ b; case other ⇒ ... } ... } + // We already know that q was matched as Left(x). Therefore, we can replace y by x in b and remove the `case other` clause altogether. + // Doing a .renameBoundVars on the cases leads to infinite loops somewhere due to incorrect alpha-conversion. + val casesSimplified = cases.map(_.simplifyOnce(withEta)) + /* + .zipWithIndex.map { case (c@CurriedE(List(headVar), _), i) ⇒ + TermExpr.substMap(c) { + case MatchE(otherTerm, otherCases) if otherTerm === termSimplified ⇒ + // We already matched `otherTerm`, and we are now in case `c`, which is `case x ⇒ ...`. + // Therefore we can discard any of the `otherCases` except the one corresponding to `c`. + // We can replace the `q match { case y ⇒ b; ...}` by `b` after replacing `x` by `y` in `b`. + val remainingCase = otherCases(i) + val result = AppE(remainingCase, headVar).simplifyOnce(withEta) + // println(s"DEBUG: replacing ${MatchE(otherTerm, otherCases)} by $result in ${c.simplifyOnce(withEta)}") + result + } + } + */ + if (casesSimplified.nonEmpty && { casesSimplified.zipWithIndex.forall { // Detect a ⇒ a pattern case (CurriedE(List(head@VarE(_, _)), body@VarE(_, _)), _) if head.name === body.name ⇒ true case (CurriedE(List(head@VarE(_, _)), DisjunctE(i, len, x, _)), ind) - if x === head && len === cases.length && ind === i + if x === head && len === ncases && ind === i ⇒ true case (CurriedE(List(head@VarE(_, headT)), DisjunctE(i, len, NamedConjunctE(projectionTerms, conjT), _)), ind) ⇒ - len === cases.length && ind === i && headT === conjT && + len === ncases && ind === i && headT === conjT && projectionTerms.zipWithIndex.forall { case (ProjectE(k, head1), j) if k === j && head1 === head ⇒ true case _ ⇒ false diff --git a/src/main/scala/io/chymyst/ch/TheoremProver.scala b/src/main/scala/io/chymyst/ch/TheoremProver.scala index 6d7e53a..5317f20 100644 --- a/src/main/scala/io/chymyst/ch/TheoremProver.scala +++ b/src/main/scala/io/chymyst/ch/TheoremProver.scala @@ -93,9 +93,10 @@ object TheoremProver { val transformedProofs = explodedNewProofs.map(ruleResult.backTransform) val t1 = System.currentTimeMillis() - val result = transformedProofs.sortBy(_.informationLossScore).take(maxTermsToSelect(sequent)) + val result = transformedProofs.map(_.simplifyOnce(withEta = false)).distinct.sortBy(_.informationLossScore).take(maxTermsToSelect(sequent)) // Note: at this point, it is a mistake to do prettyRename, because we are calling this function recursively. // We will call prettyRename() at the very end of the proof search. + // It is also a mistake to do a `.simplifyOnce(withEta = true)`. The eta-conversion produces incorrect code here. if (debug) { println(s"DEBUG: elapsed ${System.currentTimeMillis() - t0} ms, .map(_.simplify()).distinct took ${System.currentTimeMillis() - t1} ms, produced ${result.size} terms out of ${transformedProofs.size} back-transformed terms; after rule ${ruleResult.ruleName} for sequent $sequent") // println(s"DEBUG: for sequent $sequent, after rule ${ruleResult.ruleName}, transformed ${transformedProofs.length} proof terms:\n ${transformedProofs.mkString(" ;\n ")} ,\nafter simplifying:\n ${result.mkString(" ;\n ")} .") diff --git a/src/main/scala/io/chymyst/ch/data/CategoryTheory.scala b/src/main/scala/io/chymyst/ch/data/CategoryTheory.scala new file mode 100644 index 0000000..543d660 --- /dev/null +++ b/src/main/scala/io/chymyst/ch/data/CategoryTheory.scala @@ -0,0 +1,75 @@ +package io.chymyst.ch.data + +// Declarations of standard type classes, to be used in macros. + +trait Semigroup[T] { + def combine(x: T, y: T): T +} + +trait Monoid[T] extends Semigroup[T] { + def empty: T +} + +object Monoid { + def empty[T](implicit ev: Monoid[T]): T = ev.empty + + implicit class MonoidSyntax[T](t: T)(implicit ev: Monoid[T]) { + + def combine(y: T): T = ev.combine(t, y) + } + +} + +trait Functor[F[_]] { + def map[A, B](fa: F[A])(f: A ⇒ B): F[B] +} + +trait ContraFunctor[F[_]] { + def map[A, B](fa: F[A])(f: B ⇒ A): F[B] +} + +trait Filterable[F[_]] extends Functor[F] { + def deflate[A](fa: F[Option[A]]): F[A] +} + +trait ContraFilterable[F[_]] extends ContraFunctor[F] { + def inflate[A](fa: F[A]): F[Option[A]] +} + +trait Semimonad[F[_]] extends Functor[F] { + def join[A](ffa: F[F[A]]): F[A] +} + +trait Pointed[F[_]] extends Functor[F] { + def point[A]: F[A] +} + +trait Zippable[F[_]] extends Functor[F] { + def zip[A, B](fa: F[A], fb: F[B]): F[(A, B)] +} + +trait Foldable[F[_]] extends Functor[F] { + def foldMap[A, B: Monoid](fa: F[A])(f: A ⇒ B) +} + +trait Traversable[F[_]] extends Functor[F] { + def sequence[Z[_] : Zippable, A](fga: F[Z[A]]): Z[F[A]] +} + +trait Monad[F[_]] extends Pointed[F] with Semimonad[F] + +trait Applicative[F[_]] extends Pointed[F] with Zippable[F] + +trait Cosemimonad[F[_]] extends Functor[F] { + def cojoin[A](fa: F[A]): F[F[A]] +} + +trait Copointed[F[_]] extends Functor[F] { + def extract[A](fa: F[A]): A +} + +trait Comonad[F[_]] extends Copointed[F] with Cosemimonad[F] + +trait Cozippable[F[_]] extends Functor[F] { + def decide[A, B](fab: F[Either[A, B]]): Either[F[A], F[B]] +} diff --git a/src/main/scala/io/chymyst/ch/data/Monoid.scala b/src/main/scala/io/chymyst/ch/data/Monoid.scala deleted file mode 100644 index dba460d..0000000 --- a/src/main/scala/io/chymyst/ch/data/Monoid.scala +++ /dev/null @@ -1,17 +0,0 @@ -package io.chymyst.ch.data - -trait Monoid[T] { - def empty: T - - def combine(x: T, y: T): T -} - -object Monoid { - def empty[T](implicit ev: Monoid[T]): T = ev.empty - - implicit class MonoidSyntax[T](t: T)(implicit ev: Monoid[T]) { - - def combine(y: T): T = ev.combine(t, y) - } - -} diff --git a/src/main/scala/io/chymyst/ch/data/LawChecking.scala b/src/main/scala/io/chymyst/ch/data/SymbolicLawChecking.scala similarity index 89% rename from src/main/scala/io/chymyst/ch/data/LawChecking.scala rename to src/main/scala/io/chymyst/ch/data/SymbolicLawChecking.scala index 18fc221..6949d33 100644 --- a/src/main/scala/io/chymyst/ch/data/LawChecking.scala +++ b/src/main/scala/io/chymyst/ch/data/SymbolicLawChecking.scala @@ -2,14 +2,14 @@ package io.chymyst.ch.data import io.chymyst.ch._ -object LawChecking { +object SymbolicLawChecking { def checkFlattenAssociativity(fmap: TermExpr, flatten: TermExpr): Boolean = { // fmap ftn . ftn = ftn . ftn val lhs = flatten :@@ flatten val rhs = (fmap :@ flatten) :@@ flatten // println(s"check associativity laws for flatten = ${flatten.prettyPrint}:\n\tlhs = ${lhs.simplify.prettyRenamePrint}\n\trhs = ${rhs.simplify.prettyRenamePrint}") - lhs equiv rhs + TermExpr.extEqual(lhs, rhs) } def checkPureFlattenLaws(fmap: TermExpr, pure: TermExpr, flatten: TermExpr): Boolean = { @@ -23,7 +23,7 @@ object LawChecking { val fpf = (fmap :@ pure) :@@ flatten // println(s"check identity laws for pure = ${pure.prettyPrint} and flatten = ${flatten.prettyPrint}:\n\tlhs1 = ${pf.simplify.prettyPrint}\n\trhs1 = ${idFA.simplify.prettyPrint}\n\tlhs2 = ${fpf.simplify.prettyPrint}\n\trhs2 = ${idFA.simplify.prettyPrint}") - (pf equiv idFA) && (fpf equiv idFA) + TermExpr.extEqual(pf, idFA) && TermExpr.extEqual(fpf, idFA) } } diff --git a/src/main/tut/Tutorial.md b/src/main/tut/Tutorial.md index 9185776..1a37837 100644 --- a/src/main/tut/Tutorial.md +++ b/src/main/tut/Tutorial.md @@ -725,8 +725,8 @@ getIdAutoTerm equiv getId.prettyRename | `a.substTypeVar(b, c)` | `TermExpr ⇒ (TermExpr, TermExpr) ⇒ TermExpr` | replace a type variable in `a`; the type variable is specified as the type of `b`, and the replacement type is specified as the type of `c` | | `a.substTypeVars(s)` | `TermExpr ⇒ Map[TP, TypeExpr] ⇒ TermExpr` | replace all type variables in `a` according to the given substitution map `s` -- all type variables are substituted at once | | `u()` | `TermExpr ⇒ () ⇒ TermExpr` and `TypeExpr ⇒ () ⇒ TermExpr` | create a "named Unit" term of type `u.t` -- the type of `u` must be a named unit type, e.g. `None.type` or a case class with no constructors | -| `c(x...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c.t` -- the type of `c` must be a conjunction whose parts match the types of the arguments `x...` | -| `d(x)` | `TermExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d.t` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type | +| `c(x...)` | `TypeExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c` -- the type `c` must be a conjunction whose parts match the types of the arguments `x...` | +| `d(x)` | `TypeExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type | | `c(i)` | `TermExpr ⇒ Int ⇒ TermExpr` | project a conjunction term onto part with given zero-based index -- the type of `c` must be a conjunction with sufficiently many parts | | `c("id")` | `TermExpr ⇒ String ⇒ TermExpr` | project a conjunction term onto part with given accessor name -- the type of `c` must be a named conjunction that supports this accessor | | `d.cases(x =>: ..., y =>: ..., ...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` | create a term that pattern-matches on the given disjunction term -- the type of `d` must be a disjunction whose arguments match the arguments `x`, `y`, ... of the given case clauses | diff --git a/src/test/scala/io/chymyst/ch/unit/LJTSpec3.scala b/src/test/scala/io/chymyst/ch/unit/LJTSpec3.scala index 157f402..d426ca2 100644 --- a/src/test/scala/io/chymyst/ch/unit/LJTSpec3.scala +++ b/src/test/scala/io/chymyst/ch/unit/LJTSpec3.scala @@ -123,6 +123,14 @@ class LJTSpec3 extends FlatSpec with Matchers with BeforeAndAfterEach { f3.size shouldEqual 1 } + it should "generate methods for State monad with Int state" in { + final case class S_int[A](run: Int => (A, Int)) // State monad with internal state of type Int. + val wu1: S_int[Unit] = S_int { i => ((), i) } + wu1.run(123) shouldEqual (((), 123)) + val wu2: S_int[Unit] = implement // Expect the same + wu2.run(123) shouldEqual (((), 123)) + } + it should "generate methods for Continuation monad with no ambiguity" in { case class Cont[X, R](c: (X ⇒ R) ⇒ R) diff --git a/src/test/scala/io/chymyst/ch/unit/LJTSpec4.scala b/src/test/scala/io/chymyst/ch/unit/LJTSpec4.scala index 0e2edc5..6fa78c2 100644 --- a/src/test/scala/io/chymyst/ch/unit/LJTSpec4.scala +++ b/src/test/scala/io/chymyst/ch/unit/LJTSpec4.scala @@ -1,7 +1,7 @@ package io.chymyst.ch.unit import io.chymyst.ch._ -import io.chymyst.ch.data.{LawChecking => LC} +import io.chymyst.ch.data.{SymbolicLawChecking => LC} import org.scalatest.{FlatSpec, Matchers} class LJTSpec4 extends FlatSpec with Matchers { diff --git a/src/test/scala/io/chymyst/ch/unit/LambdaTermsSpec.scala b/src/test/scala/io/chymyst/ch/unit/LambdaTermsSpec.scala index 6d8bfdb..d9795d3 100644 --- a/src/test/scala/io/chymyst/ch/unit/LambdaTermsSpec.scala +++ b/src/test/scala/io/chymyst/ch/unit/LambdaTermsSpec.scala @@ -1,7 +1,7 @@ package io.chymyst.ch.unit import io.chymyst.ch._ -import io.chymyst.ch.data.{LawChecking => LC} +import io.chymyst.ch.data.{SymbolicLawChecking => LC} import org.scalatest.{Assertion, FlatSpec, Matchers} class LambdaTermsSpec extends FlatSpec with Matchers { @@ -604,7 +604,7 @@ class LambdaTermsSpec extends FlatSpec with Matchers { // Compute flatten terms from flm terms val ftnTerms = flmTerms.map(flm ⇒ (flm :@ (px =>: px)).simplify) - if (debug) println(s"flatten terms: ${ftnTerms.map(_.prettyPrint)}") + if (debug) println(s"flatten terms (type ${ftnTerms.head.t.prettyPrint}):\n\t${ftnTerms.map(_.prettyPrint).mkString("\n\t")}") val pureTerms = TheoremProver.findProofs(pureVar.t)._2 if (debug) println(s"pure terms: ${pureTerms.map(_.prettyPrint)}") @@ -660,7 +660,7 @@ class LambdaTermsSpec extends FlatSpec with Matchers { println("Good monads:") println(goodMonads.map { case (pure, ftn) ⇒ s"pure = ${pure.prettyPrint}, flatten = ${ftn.prettyPrint}" }.mkString("\n")) - goodSemimonads.size shouldEqual 2 + goodSemimonads.size shouldEqual 3 goodMonads.size shouldEqual 1 } @@ -673,7 +673,7 @@ class LambdaTermsSpec extends FlatSpec with Matchers { def pure[A] = freshVar[A ⇒ P[A]] - val (goodSemimonads, goodMonads) = semimonadsAndMonads(fmapTerm, pure, flm) + val (goodSemimonads, goodMonads) = semimonadsAndMonads(fmapTerm, pure, flm, debug = true) println(s"Good semimonads (flatten):\n${goodSemimonads.map(_.prettyPrint).mkString("\n")}") @@ -720,7 +720,7 @@ class LambdaTermsSpec extends FlatSpec with Matchers { println("Good monads:") println(goodMonads.map { case (pure, ftn) ⇒ s"pure = ${pure.prettyPrint}, flatten = ${ftn.prettyPrint}" }.mkString("\n")) - goodSemimonads.size shouldEqual 13 + goodSemimonads.size shouldEqual 19 goodMonads.size shouldEqual 6 } @@ -744,6 +744,26 @@ class LambdaTermsSpec extends FlatSpec with Matchers { goodMonads.size shouldEqual 2 } + it should "check A x A + A x A x A monad" in { + type P[A] = Either[(A, A), (A, A, A)] + + def fmapTerm[A, B] = ofType[(A ⇒ B) ⇒ P[A] ⇒ P[B]].lambdaTerm + + def flm[A, B] = freshVar[(A ⇒ P[B]) ⇒ P[A] ⇒ P[B]] + + def pure[A] = freshVar[A ⇒ P[A]] + + val (goodSemimonads, goodMonads) = semimonadsAndMonads(fmapTerm, pure, flm) + + println(s"Good semimonads (flatten):\n${goodSemimonads.map(_.prettyPrint).mkString("\n")}") + + println("Good monads:") + println(goodMonads.map { case (pure, ftn) ⇒ s"pure = ${pure.prettyPrint}, flatten = ${ftn.prettyPrint}" }.mkString("\n")) + + goodSemimonads.size shouldEqual 12 + goodMonads.size shouldEqual 2 + } + it should "check A + (1 ⇒ A) monad" in { type P[A] = Either[A, Unit ⇒ A] @@ -760,7 +780,76 @@ class LambdaTermsSpec extends FlatSpec with Matchers { println("Good monads:") println(goodMonads.map { case (pure, ftn) ⇒ s"pure = ${pure.prettyPrint}, flatten = ${ftn.prettyPrint}" }.mkString("\n")) - goodSemimonads.size shouldEqual 2 + goodSemimonads.size shouldEqual 3 goodMonads.size shouldEqual 1 } + + it should "check the 1 + A x A monad by hand" in { + type P[A] = Option[(A, A)] + + def fmap[A, B] = ofType[(A ⇒ B) ⇒ P[A] ⇒ P[B]].lambdaTerm + + fmap.prettyPrint shouldEqual "a ⇒ b ⇒ b match { c ⇒ (None() + 0); d ⇒ (0 + Some(Tuple2(a d.value._1, a d.value._2))) }" + + def pure[A] = ofType[A ⇒ P[A]].lambdaTerm + + pure.prettyPrint shouldEqual "a ⇒ (0 + Some(Tuple2(a, a)))" + + // Construct flatten by hand. + def ppa[A] = freshVar[P[P[A]]] + + def pa[A] = freshVar[P[A]] + + val none = freshVar[None.type] + val ppaNone = ppa(none) + val paNone = pa(none) + ppaNone.t.prettyPrint shouldEqual "Option[Tuple2[Option[Tuple2[A,A]],Option[Tuple2[A,A]]]]" + ppa.t.prettyPrint shouldEqual ppaNone.t.prettyPrint + + def tuplePa[A] = freshVar[Some[(P[A], P[A])]] + + def tupleA0[A] = freshVar[Some[(A, A)]] + + def tupleA1[A] = freshVar[Some[(A, A)]] + + def tupleA2[A] = freshVar[Some[(A, A)]] + + val ftn = ( + ppa =>: ppa.cases( + none =>: paNone, + tuplePa =>: tuplePa("value")(0).cases( + none =>: paNone, + tupleA0 =>: tuplePa("value")(1).cases( + none =>: paNone, + tupleA1 =>: + pa(tupleA0(tupleA0("value").t(tupleA0("value")(0), tupleA1("value")(1)))) + ) + ) + ) + ).prettyRename + ftn.t.prettyPrint shouldEqual "Option[Tuple2[Option[Tuple2[A,A]],Option[Tuple2[A,A]]]] ⇒ Option[Tuple2[A,A]]" + ftn.prettyPrint shouldEqual "a ⇒ a match { b ⇒ (b + 0); c ⇒ c.value._1 match { d ⇒ (d + 0); e ⇒ c.value._2 match { f ⇒ (f + 0); g ⇒ (0 + Some(Tuple2(e.value._1, g.value._2))) } } }" + // Fails due to violating laws. + LC.checkPureFlattenLaws(fmap, pure, ftn) shouldEqual true + LC.checkFlattenAssociativity(fmap, ftn) shouldEqual false + + val ftn2 = ( + ppa =>: ppa.cases( + none =>: paNone, + tuplePa =>: tuplePa("value")(0).cases( + none =>: tuplePa("value")(1), + tupleA0 =>: tuplePa("value")(1).cases( + none =>: tuplePa("value")(0), + tupleA1 =>: + pa(tupleA0(tupleA0("value").t(tupleA0("value")(0), tupleA1("value")(1)))) + ) + ) + ) + ).simplify.prettyRename + ftn2.t.prettyPrint shouldEqual "Option[Tuple2[Option[Tuple2[A,A]],Option[Tuple2[A,A]]]] ⇒ Option[Tuple2[A,A]]" + ftn2.prettyPrint shouldEqual "a ⇒ a match { b ⇒ (b + 0); c ⇒ c.value._1 match { d ⇒ c.value._2; e ⇒ c.value._2 match { f ⇒ c.value._1; g ⇒ (0 + Some(Tuple2(e.value._1, g.value._2))) } } }" + LC.checkPureFlattenLaws(fmap, pure, ftn2) shouldEqual true + LC.checkFlattenAssociativity(fmap, ftn2) shouldEqual false + } + } diff --git a/src/test/scala/io/chymyst/ch/unit/TermExprSpec.scala b/src/test/scala/io/chymyst/ch/unit/TermExprSpec.scala index f3239a1..7510763 100644 --- a/src/test/scala/io/chymyst/ch/unit/TermExprSpec.scala +++ b/src/test/scala/io/chymyst/ch/unit/TermExprSpec.scala @@ -15,8 +15,43 @@ class TermExprSpec extends FlatSpec with Matchers { behavior of "TermExpr miscellaneous methods" + it should "generate variables for disjunction subtypes" in { + val p = freshVar[Either[Option[Option[(Int, Int)]], Option[(Option[Int], Option[Int])]]] + + val subtypeVars = TermExpr.subtypeVars(p.t).map(_.prettyPrint) + + val indices = "z([0-9]+)".r.findAllMatchIn(subtypeVars.mkString("")).map(_.group(1).toInt).toList + indices.map(_ - indices.min) shouldEqual Seq(0, 1, 3, 2, 2, 3) + + subtypeVars.map(_.replaceAll("z[0-9]+", "z")) shouldEqual List( + "(Left((None() + 0)) + 0)", + "(Left((0 + Some((None() + 0)))) + 0)", + "(Left((0 + Some((0 + Some(Tuple2(z, z)))))) + 0)", + "(0 + Right((None() + 0)))", "(0 + Right((0 + Some(Tuple2((None() + 0), (None() + 0))))))", + "(0 + Right((0 + Some(Tuple2((None() + 0), (0 + Some(z)))))))", + "(0 + Right((0 + Some(Tuple2((0 + Some(z)), (None() + 0))))))", + "(0 + Right((0 + Some(Tuple2((0 + Some(z)), (0 + Some(z)))))))" + ) + + val subtypeVarsScala = TermExpr.subtypeVars(p.t).map(_.printScala).map(_.replaceAll("z[0-9]+", "z")) shouldEqual + List("Left(None)", "Left(Some(None))", "Left(Some(Some(Tuple2(z, z))))", "Right(None)", "Right(Some(Tuple2(None, None)))", "Right(Some(Tuple2(None, Some(z))))", "Right(Some(Tuple2(Some(z), None)))", "Right(Some(Tuple2(Some(z), Some(z))))") + } + + it should "compute Scala code of flatten for Option" in { + def flattenOpt[A]: Option[Option[A]] ⇒ Option[A] = implement + + val flattenScala = flattenOpt.lambdaTerm.printScala + + flattenScala shouldEqual "a: Option[Option[A]] ⇒ a match { case b: None.type ⇒ None; case c: Some[Option[A]] ⇒ c.value }" + } + + it should "compute extensional equality of functions" in { + TermExpr.extEqual(TermExpr.id(typeExpr[Int]), TermExpr.id(typeExpr[Int])) shouldEqual true + } + it should "compute identity function" in { def idAB[A, B] = TermExpr.id(typeExpr[A ⇒ B]) + idAB.prettyPrint shouldEqual "x ⇒ x" idAB.toString shouldEqual "\\((x:A ⇒ B) ⇒ x)" idAB.t.prettyPrint shouldEqual "(A ⇒ B) ⇒ A ⇒ B" @@ -41,7 +76,7 @@ class TermExprSpec extends FlatSpec with Matchers { the[Exception] thrownBy { TermExpr.subst(var23, termExpr1, var32 =>: var23.copy(t = TP("1"))) shouldEqual (var32 =>: termExpr1) - } should have message "In subst(x2, \\((x2:3) ⇒ (x3:2) ⇒ (x4:1) ⇒ x3), \\((x3:2) ⇒ x2)), found variable(s) (x2:1) with incorrect type(s), expected variable type 3" + } should have message "In subst(x2:3, \\((x2:3) ⇒ (x3:2) ⇒ (x4:1) ⇒ x3), \\((x3:2) ⇒ x2)), found variable(s) (x2:1) with incorrect type(s), expected variable type 3" } it should "recover from incorrect substitution" in { @@ -58,6 +93,15 @@ class TermExprSpec extends FlatSpec with Matchers { have message "Incorrect substitution of bound variable x2 by non-variable Tuple2(x, x) in substMap(x2 ⇒ x1){...}" } + behavior of "printScala" + + it should "print functions in Scala syntax" in { + termExpr1.printScala shouldEqual "x2 ⇒ x3 ⇒ x4 ⇒ x3" + termExpr2.printScala shouldEqual "x2 ⇒ x3 ⇒ x4 ⇒ x1" + termExpr3.printScala shouldEqual "x1 ⇒ x2 ⇒ x3 ⇒ x4 ⇒ x1" + + } + behavior of "TermExpr#renameVar" it should "rename one variable" in { @@ -274,9 +318,9 @@ a ⇒ Tuple2(a._2._2, a._2._2) // Choose second element of second inner tuple. // println(flattens.size) - def f[A] = allOfType[Option[(A, A, A)] ⇒ Option[(A, A, A)]]() + def f[A] = anyOfType[Option[(A, A, A)] ⇒ Option[(A, A, A)]]() - println(f.size) + f.size shouldEqual 28 // f[Int].map(_.lambdaTerm.prettyPrint).sorted.foreach(println) // f.size shouldEqual factorial(4) } @@ -284,9 +328,7 @@ a ⇒ Tuple2(a._2._2, a._2._2) // Choose second element of second inner tuple. it should "generate match clauses" in { def f[A] = anyOfType[Option[Option[A]] ⇒ Option[Option[A]]]().map(_.lambdaTerm) - println(f.size) - f.map(_.prettyPrint).foreach(println) - + f.size shouldEqual 13 + // f.map(_.prettyPrint).foreach(println) } - }