diff --git a/deps/nlohmann_json b/deps/nlohmann_json index d33ecd3f..a0e9fb1e 160000 --- a/deps/nlohmann_json +++ b/deps/nlohmann_json @@ -1 +1 @@ -Subproject commit d33ecd3f3bd11e30aa8bbabb00e0a9cd3f2456d8 +Subproject commit a0e9fb1e638cfbb5b8b556b7c51eaa81977bad48 diff --git a/sol-core.cabal b/sol-core.cabal index a4f5ddff..852ac56b 100644 --- a/sol-core.cabal +++ b/sol-core.cabal @@ -61,7 +61,9 @@ library -- cabal-fmt: expand src exposed-modules: Solcore.Backend.EmitHull + Solcore.Backend.Retype Solcore.Backend.Specialise + Solcore.Backend.SpecMonad Solcore.Desugarer.FieldAccess Solcore.Desugarer.IfDesugarer Solcore.Desugarer.IndirectCall diff --git a/src/Common/Pretty.hs b/src/Common/Pretty.hs index 94f1900c..ce5618ae 100644 --- a/src/Common/Pretty.hs +++ b/src/Common/Pretty.hs @@ -7,6 +7,8 @@ module Common.Pretty , angles , pPrint , pShow +, prettys +, prettysWith ) where import Text.PrettyPrint hiding((<>)) import Text.PrettyPrint qualified as PP @@ -39,3 +41,9 @@ commaSepList = hsep . punctuate comma . map ppr angles :: Doc -> Doc angles d = char '<' >< d >< char '>' + +prettys :: Pretty a => [a] -> String +prettys = prettysWith ppr + +prettysWith :: (a -> Doc) -> [a] -> String +prettysWith pr = render . brackets . commaSep . 
map pr diff --git a/src/Solcore/Backend/EmitHull.hs b/src/Solcore/Backend/EmitHull.hs index cf6fbdd1..2bf71ac7 100644 --- a/src/Solcore/Backend/EmitHull.hs +++ b/src/Solcore/Backend/EmitHull.hs @@ -18,7 +18,7 @@ import Solcore.Frontend.TypeInference.TcMonad (insts) import Solcore.Frontend.TypeInference.TcSubst import Solcore.Frontend.TypeInference.TcUnify import Solcore.Primitives.Primitives -import Solcore.Backend.Specialise(typeOfTcExp) +import Solcore.Backend.Retype(typeOfTcExp) import System.Exit emitHull :: Bool -> TcEnv -> CompUnit Id -> IO [Hull.Object] diff --git a/src/Solcore/Backend/Retype.hs b/src/Solcore/Backend/Retype.hs new file mode 100644 index 00000000..f0bf27f7 --- /dev/null +++ b/src/Solcore/Backend/Retype.hs @@ -0,0 +1,68 @@ +module Solcore.Backend.Retype where + +import Solcore.Frontend.Syntax +import Solcore.Frontend.TypeInference.Id ( Id(..) ) +import Solcore.Primitives.Primitives +import Solcore.Frontend.Pretty.ShortName +import Solcore.Frontend.Pretty.SolcorePretty +import Common.Pretty + +type TcFunDef = FunDef Id +type TcExp = Exp Id + +typeOfTcExp :: TcExp -> Ty +typeOfTcExp (Var i) = idType i +typeOfTcExp (Con i []) = idType i +typeOfTcExp e@(Con i args) = go (idType i) args where + go ty [] = ty + go (_ :-> u) (a:as) = go u as + go _ _ = error $ "typeOfTcExp: " ++ show e +typeOfTcExp (Lit (IntLit _)) = word --TyCon "Word" [] +typeOfTcExp exp@(Call Nothing i args) = applyTo args funTy where + funTy = idType i + applyTo [] ty = ty + applyTo (_:as) (_ :-> u) = applyTo as u + applyTo _ _ = error $ concat [ "apply ", pretty i, " : ", pretty funTy + , "to", show $ map pretty args + , "\nIn:\n", show exp + ] +typeOfTcExp (Lam args body (Just tb)) = funtype tas tb where + tas = map typeOfTcParam args +typeOfTcExp (Cond _ _ e) = typeOfTcExp e +typeOfTcExp (TyExp _ ty) = ty +typeOfTcExp e = error $ "typeOfTcExp: " ++ show e + +typeOfTcStmt :: Stmt Id -> Ty +typeOfTcStmt (n := e) = unit +typeOfTcStmt (Let n _ _) = idType n +typeOfTcStmt (StmtExp 
e) = typeOfTcExp e +typeOfTcStmt (Return e) = typeOfTcExp e +typeOfTcStmt (Match _ ((pat, body):_)) = typeOfTcBody body + +typeOfTcBody :: [Stmt Id] -> Ty +typeOfTcBody [] = unit +typeOfTcBody [s] = typeOfTcStmt s +typeOfTcBody (_:b) = typeOfTcBody b + +typeOfTcParam :: Param Id -> Ty +typeOfTcParam (Typed i t) = idType i -- seems better than t - see issue #6 +typeOfTcParam (Untyped i) = idType i + +typeOfTcSignature :: Signature Id -> Ty +typeOfTcSignature sig = funtype (map typeOfTcParam $ sigParams sig) (returnType sig) where + returnType s = case sigReturn s of + Just t -> t + Nothing -> error ("no return type in signature of: " ++ show (sigName s)) + +schemeOfTcSignature :: Signature Id -> Scheme +schemeOfTcSignature sig@(Signature vs ps n args (Just rt)) + = case mapM getType args of + Just ts -> Forall vs (ps :=> (funtype ts rt)) + Nothing -> error $ unwords ["Invalid instance member signature:", pretty sig] + where + getType (Typed _ t) = Just t + getType _ = Nothing + +typeOfTcFunDef :: TcFunDef -> Ty +typeOfTcFunDef (FunDef sig _) = typeOfTcSignature sig + diff --git a/src/Solcore/Backend/SpecMonad.hs b/src/Solcore/Backend/SpecMonad.hs new file mode 100644 index 00000000..b8463925 --- /dev/null +++ b/src/Solcore/Backend/SpecMonad.hs @@ -0,0 +1,406 @@ +module Solcore.Backend.SpecMonad where +import Control.Applicative +import Control.Monad +import Control.Monad.Error.Class +import Control.Monad.State +import Data.Generics +import Data.List(union, (\\)) +import Data.Maybe(fromMaybe) +import qualified Data.Map as Map + +import Common.Monad +import Common.Pretty +import Solcore.Backend.Retype +import Solcore.Frontend.Pretty.ShortName +import Solcore.Frontend.Pretty.SolcorePretty +import Solcore.Frontend.Syntax +import Solcore.Frontend.TypeInference.Id +import Solcore.Frontend.TypeInference.NameSupply +import Solcore.Frontend.TypeInference.TcEnv(TcEnv(typeTable),TypeInfo(..)) +import Solcore.Frontend.TypeInference.TcUnify(typesDoNotUnify) + + +-- ** 
Specialisation state and monad +-- SpecState and SM are meant to be local to this module. +type Table a = Map.Map Name a +emptyTable :: Table a +emptyTable = Map.empty + +type Resolution = (Ty, TcFunDef) +data SpecState = SpecState + { spResTable :: Table [Resolution] + , specTable :: Table TcFunDef + , spTypeTable :: Table TypeInfo + , spDataTable :: Table DataTy + , spGlobalEnv :: TcEnv + , splocalEnv :: Table Ty + , spSubst :: TVSubst + , spDebug :: Bool + , spNS :: NameSupply + } + + +type SM = StateT SpecState IO + +getDebug :: SM Bool +getDebug = gets spDebug + +withDebug m = do + savedDebug <- getDebug + modify $ \s -> s { spDebug = True } + a <- m + modify $ \s -> s { spDebug = savedDebug } + return a + +whenDebug m = do + debug <- getDebug + when debug m + +debug :: [String] -> SM () +debug msg = do + enabled <- getDebug + when enabled $ writes msg + +runSM :: Bool -> TcEnv -> SM a -> IO a +runSM debugp env m = evalStateT m (initSpecState debugp env) + + +-- | `withLocalState` runs a computation with a local state +-- local changes are discarded, with the exception of the `specTable` and name supply +withLocalState :: SM a -> SM a +withLocalState m = do + s <- get + a <- m + spTable <- gets specTable + ns <- gets spNS + put s + modify $ \s -> s { specTable = spTable, spNS = ns } + return a + +initSpecState :: Bool ->TcEnv -> SpecState +initSpecState debugp env = SpecState + { spResTable = emptyTable + , specTable = emptyTable + , spTypeTable = typeTable env + , spDataTable = Map.empty + , spGlobalEnv = env + , splocalEnv = emptyTable + , spSubst = emptyTVSubst + , spDebug = debugp + , spNS = namePool + } + + +{- +-- make type variables flexible by replacing them with metas +flex :: Ty -> Ty +flex (TyVar (TVar n)) = Meta (MetaTv n) +flex (TyCon cn tys) = TyCon cn (map flex tys) +flex t = t + +-- make all type variables flexible in a syntactic construct +flexAll :: Data a => a -> a +flexAll = everywhere (mkT flex) +-} + +-- | A signature forall tvs . 
t is considered ambiguous if `tvs \\ FTV(t) /= mempty` +-- this should be the same as `FTV(body) \\ FTV(t) /= {}` +-- returns list of ambiguous variables +ambiguousVarsInSig :: HasTV a => Signature a -> [Tyvar] +ambiguousVarsInSig sig = sigVars sig \\ freetv (sigParams sig, sigReturn sig) + +addSpecialisation :: Name -> TcFunDef -> SM () +addSpecialisation name fd = modify $ \s -> s { specTable = Map.insert name fd (specTable s) } + +lookupSpecialisation :: Name -> SM (Maybe TcFunDef) +lookupSpecialisation name = gets (Map.lookup name . specTable) + +addResolution :: Name -> Ty -> TcFunDef -> SM () +addResolution name ty fun = do + -- debug ["+ addResolution ", pretty name, "@", pretty ty, " |-> ", shortName fun] + let sig = funSignature fun + reportAmbiguousVars sig + modify $ \s -> s { spResTable = Map.insertWith (++) name [(ty, fun)] (spResTable s) } + where + reportAmbiguousVars sig = do + let vars = ambiguousVarsInSig sig + let scheme = schemeOfTcSignature sig + unless (null vars) $ nopanics [ "Error: function ", pretty name + ," cannot be specialised because it has an ambiguous type:\n " + , pretty scheme + ,"\n variables: ", prettys vars + ,"\n do not occur in the argument/result types." + ] + +lookupResolution :: Name -> Ty -> SM (Maybe (TcFunDef, Ty, TVSubst)) +lookupResolution name ty = gets (Map.lookup name . spResTable) >>= findMatch ty where + str :: Pretty a => a -> String + str = pretty + findMatch :: Ty -> Maybe [Resolution] -> SM (Maybe (TcFunDef, Ty, TVSubst)) + findMatch etyp (Just res) = do + debug ["|> findMatch ", pretty etyp, " in ", prettysWith pprRes res] + firstMatch etyp res + findMatch _ Nothing = return Nothing + firstMatch :: Ty -> [Resolution] -> SM (Maybe (TcFunDef, Ty, TVSubst)) + firstMatch etyp [] = return Nothing + firstMatch etyp ((t,e):rest) + | Right subst <- specmgu t etyp = do -- TESTME: match is too weak for MPTC, but isn't mgu too strong?
+ debug ["< lookupRes - match found for ", str name, ": ", str t, " ~ ", str etyp, " => ", str subst] + return (Just (e, t, subst)) + | otherwise = firstMatch etyp rest + +pprRes :: Resolution -> Doc +pprRes(ty, fd) = ppr ty <+> text ":" <+> text(shortName fd) + +getSpSubst :: SM TVSubst +getSpSubst = gets spSubst + +putSpSubst :: TVSubst -> SM () +putSpSubst subst = modify $ \s -> s { spSubst = subst } +extSpSubst :: TVSubst -> SM () + +extSpSubst subst = modify $ \s -> s { spSubst = spSubst s <> subst } + +atCurrentSubst :: HasTV a => a -> SM a +atCurrentSubst a = flip applytv a <$> getSpSubst + +addData :: DataTy -> SM () +addData dt = modify (\s -> s { spDataTable = Map.insert (dataName dt) dt (spDataTable s) }) + +spNewName :: SM Name +spNewName = do + s <- get + let (n, ns) = newName (spNS s) + put s { spNS = ns } + pure (addPrefix "_" n) + +-- data Name = Name String | QualName Name String +addPrefix :: String -> Name -> Name +addPrefix p (Name s) = Name (p ++ s) +addPrefix p (QualName q s) = QualName q (p ++ s) + +specmgu :: Ty -> Ty -> Either String TVSubst +specmgu (TyCon n ts) (TyCon n' ts') + | n == n' && length ts == length ts' = + specsolve (zip ts ts') mempty +specmgu (TyVar v) t = varBind v t +specmgu t (TyVar v) = varBind v t +specmgu t1 t2 = typesDoNotUnify t1 t2 + +varBind :: (MonadError String m) => Tyvar -> Ty -> m TVSubst +varBind v t + | t == TyVar v = return mempty + | v `elem` freetv t = infiniteTyErr v t + | otherwise = do + return (v |-> t) + where + infiniteTyErr w u = throwError $ + unwords + [ "Cannot construct the infinite type:" + , pretty w + , "~" + , pretty u + ] + +specsolve :: [(Ty, Ty)] -> TVSubst -> Either String TVSubst +specsolve [] s = pure s +specsolve ((t1, t2) : ts) s = + do + s1 <- specmgu (applytv s t1) (applytv s t2) + s2 <- specsolve ts s1 + pure (s2 <> s1) + +newtype TVSubst + = TVSubst { unTVSubst :: [(Tyvar, Ty)] } deriving (Eq, Show) + +restrict :: TVSubst -> [Tyvar] -> TVSubst +restrict (TVSubst s) vs + = 
TVSubst [(v,t) | (v,t) <- s, v `notElem` vs] + +emptyTVSubst :: TVSubst +emptyTVSubst = TVSubst [] + +-- composition operators +-- apply (s1 <> s2) t = apply s1 (apply s2 t) +{- +-- >>> let [a,b,c,x,y] = map (TVar . Name) (Prelude.words "a b c x y") +-- >>> [a,b,c,x,y] +-- [TVar {var = a},TVar {var = b},TVar {var = c},TVar {var = x},TVar {var = y}] +ghci> let [ta,tb,tc,tx,ty] = map TyVar [a,b,c,x,y] +ghci> [ta, tb, tc, tx, ty] +[TyVar (TVar {var = a}),TyVar (TVar {var = b}),TyVar (TVar {var = c}),TyVar (TVar {var = x}),TyVar (TVar {var = y})] + +ghci> let s1 = TVSubst [(a,tx), (b,ty)] +ghci> let s2 = TVSubst [(a,tb), (b,tc), (c,ta)] +ghci> s1 <> s1 +TVSubst {unTVSubst = [(TVar {var = a},TyVar (TVar {var = x})),(TVar {var = b},TyVar (TVar {var = y}))]} +ghci> s1 <> s2 +TVSubst {unTVSubst = [(TVar {var = a},TyVar (TVar {var = y})),(TVar {var = b},TyVar (TVar {var = c})),(TVar {var = c},TyVar (TVar {var = x}))]} +ghci> s2 <> s2 +TVSubst {unTVSubst = [(TVar {var = a},TyVar (TVar {var = c})),(TVar {var = b},TyVar (TVar {var = a})),(TVar {var = c},TyVar (TVar {var = b}))]} +ghci> s2 <> s2 <> s2 +TVSubst {unTVSubst = [(TVar {var = a},TyVar (TVar {var = a})),(TVar {var = b},TyVar (TVar {var = b})),(TVar {var = c},TyVar (TVar {var = c}))]} +-} +instance Semigroup TVSubst where + s1 <> s2 = TVSubst (outer ++ inner) + where + outer = [(u, applytv s1 t) | (u, t) <- unTVSubst s2] + inner = [(v,t) | (v,t) <- unTVSubst s1, v `notElem` dom2] + dom2 = map fst (unTVSubst s2) + +instance Monoid TVSubst where + mempty = emptyTVSubst + +(|->) :: Tyvar -> Ty -> TVSubst +u |-> t = TVSubst [(u, t)] + +instance Pretty TVSubst where + ppr = braces . commaSep . map go . 
unTVSubst + where + go (v,t) = ppr v <+> text "|->" <+> ppr t + +class Data a => HasTV a where + applytv :: TVSubst -> a -> a + applytv s = everywhere (mkT (applytv @Ty s)) + + freetv :: a -> [Tyvar] -- free variables + freetv = everything (<>) (mkQ mempty (freetv @Ty)) + + renametv :: a -> SM (a, TVRenaming) + renametv a = pure (a, mempty) + + applyRenaming :: TVRenaming -> a -> a + applyRenaming r = everywhere (mkT (renameTV r)) + +instance HasTV Ty where + applytv (TVSubst s) t@(TyVar v) + = maybe t id (lookup v s) + applytv s (TyCon n ts) + = TyCon n (applytv s ts) + applytv _ t = t + + freetv (TyVar v@(TVar _)) = [v] + freetv (TyCon _ ts) = freetv ts + freetv _ = [] + +instance HasTV a => HasTV [a] where + applytv s = map (applytv s) + freetv = foldr (union . freetv) mempty + +instance HasTV a => HasTV (Maybe a) where + applytv s = fmap (applytv s) + freetv = maybe [] freetv + +instance (HasTV a, HasTV b) => HasTV (a,b) where -- defaults + +{- +instance (HasTV a, HasTV b, HasTV c) => HasTV (a,b,c) where + applytv s (z,x,y) = (applytv s z, applytv s x, applytv s y) + freetv (z,x,y) = freetv z `union` freetv x `union` freetv y + +instance (HasTV a, HasTV b) => HasTV (a,b) where + applytv s (x,y) = (applytv s x, applytv s y) + freetv (x,y) = freetv x `union` freetv y +-} + +instance HasTV Id where + applytv s (Id n t) = Id n (applytv s t) + freetv (Id _ t) = freetv t + +instance HasTV a => HasTV (Param a) where -- defaults +instance HasTV a => HasTV (Exp a) where -- defaults +instance HasTV a => HasTV (Stmt a) where -- defaults + +instance HasTV (Pat Id) where + + +instance HasTV (Signature Id) where + applytv s = everywhere (mkT (applytv @Ty s)) + freetv sig = (everything (<>) (mkQ mempty (freetv @Ty))) sig \\ sigVars sig + renametv sig = do + renaming <- foldM addRenaming mempty (sigVars sig) + pure (applyRenaming renaming sig, renaming) + + +{- +data FunDef a + = FunDef { + funSignature :: Signature a + , funDefBody :: [Stmt a] + } deriving (Eq, Ord, Show, 
Data, Typeable) +-} + +instance HasTV (FunDef Id) where + freetv fd = (everything (<>) (mkQ mempty (freetv @Ty))) fd \\ sigVars (funSignature fd) + renametv fd = do + let sig = funSignature fd + renaming <- foldM addRenaming mempty (sigVars sig) + let subst = toTVS renaming + let sig' = applytv subst sig + let body' = applytv subst (funDefBody fd) + pure(FunDef sig' body', renaming) + +addRenaming :: TVRenaming -> Tyvar -> SM TVRenaming +addRenaming b a = do + fresh <- spNewName + pure (TVR [(a, TVar fresh)] <> b) + +newtype TVRenaming + = TVR { unTVR :: [(Tyvar, Tyvar)] } deriving (Eq, Show) + +instance Pretty TVRenaming where + ppr = braces . commaSep . map go . unTVR + where + go (v,t) = ppr v <+> text "|->" <+> ppr t + +-- composition operators +-- apply (s1 <> s2) t = apply s1 (apply s2 t) +-- renameTy ([(a,x) (b,y)] <> [(a,b), (b,c), (c,a)]) (a :-> b :-> c) +-- = renameTy ([(a,x) (b,y)] (b :-> c :-> a) +-- = y :-> c :-> x +-- Hence ([(a,x) (b,y)] <> [(a,b), (b,c), (c,a)]) = [(a,y), (b,c), (c,x)] +-- +-- >>> let [a,b,c,x,y] = map (TVar . 
Name) (Prelude.words "a b c x y") +-- >>> [a,b,c,x,y] +-- [TVar {var = a},TVar {var = b},TVar {var = c},TVar {var = x},TVar {var = y}] +-- >>> let r1 = TVR [(a,x), (b,y)] +-- >>> let r2 = TVR [(a,b), (b,c), (c,a)] +-- >>> r1 <> r1 +-- TVR {unTVR = [(TVar {var = a},TVar {var = x}),(TVar {var = b},TVar {var = y})]} +-- >>> r1 <> r2 +-- TVR {unTVR = [(TVar {var = a},TVar {var = y}),(TVar {var = b},TVar {var = c}),(TVar {var = c},TVar {var = x})]} +-- >>> r2 <> r2 +-- TVR {unTVR = [(TVar {var = a},TVar {var = c}),(TVar {var = b},TVar {var = a}),(TVar {var = c},TVar {var = b})]} +-- >>> r2 <> r2 <> r2 +-- TVR {unTVR = []} +-- >>> r1 <> mempty +-- TVR {unTVR = [(TVar {var = a},TVar {var = x}),(TVar {var = b},TVar {var = y})]} + + +instance Semigroup TVRenaming where + r1 <> r2 = TVR (filter (uncurry (/=)) (outer ++ inner)) + where + outer = [(u, renameTV r1 v) | (u, v) <- unTVR r2] + inner = [(v, t) | (v, t) <- unTVR r1, v `notElem` domR2] + domR2 = map fst (unTVR r2) + +instance Monoid TVRenaming where + mempty = TVR mempty + +toTVS :: TVRenaming -> TVSubst +toTVS = TVSubst . map (fmap TyVar) . unTVR + +fromTVS :: TVSubst -> TVRenaming +fromTVS = TVR . map (fmap unTyVar) . unTVSubst where + unTyVar (TyVar x) = x + unTyVar t = error("fromTVS: " ++ pretty t ++ "is not a type variable") + +renameTV :: TVRenaming -> Tyvar -> Tyvar +renameTV (TVR r) v = fromMaybe v (lookup v r) + +renameTy :: TVRenaming -> Ty -> Ty +renameTy = applyRenaming + +renameSubst :: TVRenaming -> TVSubst -> TVSubst +renameSubst r = TVSubst . map rename . 
unTVSubst where + rename (v, t) = (renameTV r v, renameTy r t) diff --git a/src/Solcore/Backend/Specialise.hs b/src/Solcore/Backend/Specialise.hs index 6dee19e7..67b2dbd4 100644 --- a/src/Solcore/Backend/Specialise.hs +++ b/src/Solcore/Backend/Specialise.hs @@ -1,193 +1,71 @@ -- {-# LANGUAGE DefaultSignatures #-} -module Solcore.Backend.Specialise where --(specialiseCompUnit, typeOfTcExp) where -{- * Specialisation -Create specialised versions of polymorphic and overloaded functions. -This is meant to be run on typed and defunctionalised code, so no higher-order functions. +{-| +Module: Solcore.Backend.Specialise +Description: Monomorphization pass - eliminates polymorphism via specialization + +This module implements whole-program specialization (monomorphization) that transforms +polymorphic and type class-overloaded code into concrete, monomorphic definitions. + += Algorithm Overview + +The specialization process: + + 1. Builds a resolution table mapping (name, type) → definition + 2. Analyzes call sites to discover needed type instantiations + 3. Creates specialized versions with mangled names (e.g., map$word, map$bool) + 4. Resolves type class instances to concrete implementations + 5. Recursively specializes all called functions + += Pipeline Position + +Runs after type checking and defunctionalization, but before Core emission. +Requires whole-program analysis (must see all code at once). 
+ +Input: Typed AST with polymorphic functions and type class constraints +Output: Monomorphic AST with all type variables eliminated + -} +module Solcore.Backend.Specialise(specialiseCompUnit) where -import Common.Monad import Control.Applicative import Control.Monad -import Control.Monad.Except -import Control.Monad.Reader import Control.Monad.State -import Data.Generics -import Data.List(intercalate, union, (\\)) -import Data.Maybe(fromMaybe) +import Data.List(intercalate) import qualified Data.Map as Map -import GHC.Stack + +import Common.Monad +import Solcore.Backend.Retype +import Solcore.Backend.SpecMonad import Solcore.Desugarer.IfDesugarer(desugaredBoolTy) +import Solcore.Frontend.Pretty.ShortName import Solcore.Frontend.Pretty.SolcorePretty -import Solcore.Frontend.Syntax +import Solcore.Frontend.Syntax hiding(name, decls) import Solcore.Frontend.TypeInference.Id ( Id(..) ) -import Solcore.Frontend.TypeInference.NameSupply -import Solcore.Frontend.TypeInference.TcEnv(TcEnv(..),TypeInfo(..)) -import qualified Solcore.Frontend.TypeInference.TcSubst as TcSubst -import Solcore.Frontend.TypeInference.TcUnify(typesDoNotUnify) -import Solcore.Frontend.Pretty.ShortName +import Solcore.Frontend.TypeInference.TcEnv(TcEnv) import Solcore.Primitives.Primitives -import System.Exit -import Common.Pretty - --- ** Specialisation state and monad --- SpecState and SM are meant to be local to this module. 
-type Table a = Map.Map Name a -emptyTable :: Table a -emptyTable = Map.empty - -type TcFunDef = FunDef Id -type TcExp = Exp Id - -type Resolution = (Ty, TcFunDef) -data SpecState = SpecState - { spResTable :: Table [Resolution] - , specTable :: Table TcFunDef - , spTypeTable :: Table TypeInfo - , spDataTable :: Table DataTy - , spGlobalEnv :: TcEnv - , splocalEnv :: Table Ty - , spSubst :: TVSubst - , spDebug :: Bool - , spNS :: NameSupply - } - - -type SM = StateT SpecState IO - -getDebug :: SM Bool -getDebug = gets spDebug - -withDebug m = do - savedDebug <- getDebug - modify $ \s -> s { spDebug = True } - a <- m - modify $ \s -> s { spDebug = savedDebug } - return a - -whenDebug m = do - debug <- getDebug - when debug m - -debug :: [String] -> SM () -debug msg = do - enabled <- getDebug - when enabled $ writes msg - -runSM :: Bool -> TcEnv -> SM a -> IO a -runSM debugp env m = evalStateT m (initSpecState debugp env) - --- prettys :: Pretty a => [a] -> String --- prettys = render . brackets . commaSep . 
map ppr - --- | `withLocalState` runs a computation with a local state --- local changes are discarded, with the exception of the `specTable` and name supply -withLocalState :: SM a -> SM a -withLocalState m = do - s <- get - a <- m - spTable <- gets specTable - ns <- gets spNS - put s - modify $ \s -> s { specTable = spTable, spNS = ns } - return a - -initSpecState :: Bool ->TcEnv -> SpecState -initSpecState debugp env = SpecState - { spResTable = emptyTable - , specTable = emptyTable - , spTypeTable = typeTable env - , spDataTable = Map.empty - , spGlobalEnv = env - , splocalEnv = emptyTable - , spSubst = emptyTVSubst - , spDebug = debugp - , spNS = namePool - } -{- --- make type variables flexible by replacing them with metas -flex :: Ty -> Ty -flex (TyVar (TVar n)) = Meta (MetaTv n) -flex (TyCon cn tys) = TyCon cn (map flex tys) -flex t = t - --- make all type variables flexible in a syntactic construct -flexAll :: Data a => a -> a -flexAll = everywhere (mkT flex) --} +------------------------------------------------------------------------------- +-- Constants +------------------------------------------------------------------------------- + +-- | The entry point name for contract runtime code +entryPointName :: Name +entryPointName = Name "main" + +-- | The constructor function name +constructorName :: Name +constructorName = Name "constructor" + +-- | Built-in revert primitive name +revertBuiltin :: Name +revertBuiltin = Name "revert" --- | A signature forall tvs . 
t is considered ambiguous if `tvs \\ FTV(t) /= mempty` --- this is should be the same as `FTV(body) \\ FTV(t) /= {}` --- returns list of ambiguous variables -ambiguousVarsInSig :: HasTV a => Signature a -> [Tyvar] -ambiguousVarsInSig sig = sigVars sig \\ freetv (sigParams sig, sigReturn sig) - -addSpecialisation :: Name -> TcFunDef -> SM () -addSpecialisation name fd = modify $ \s -> s { specTable = Map.insert name fd (specTable s) } - -lookupSpecialisation :: Name -> SM (Maybe TcFunDef) -lookupSpecialisation name = gets (Map.lookup name . specTable) - -addResolution :: Name -> Ty -> TcFunDef -> SM () -addResolution name ty fun = do - -- debug ["+ addResolution ", pretty name, "@", pretty ty, " |-> ", shortName fun] - let sig = funSignature fun - reportAmbiguousVars sig - modify $ \s -> s { spResTable = Map.insertWith (++) name [(ty, fun)] (spResTable s) } - where - reportAmbiguousVars sig = do - let vars = ambiguousVarsInSig sig - let scheme = schemeOfTcSignature sig - unless (null vars) $ nopanics [ "Error: function ", pretty name - ," cannot be specialised because it has an ambiguous type:\n " - , pretty scheme - ,"\n variables: ", prettys vars - ,"\n do not occur in the argument/result types." - ] -lookupResolution :: Name -> Ty -> SM (Maybe (TcFunDef, Ty, TVSubst)) -lookupResolution name ty = gets (Map.lookup name . spResTable) >>= findMatch ty where - str :: Pretty a => a -> String - str = pretty - findMatch :: Ty -> Maybe [Resolution] -> SM (Maybe (TcFunDef, Ty, TVSubst)) - findMatch etyp (Just res) = do - debug ["|> findMatch ", pretty etyp, " in ", prettys res] - firstMatch etyp res - findMatch _ Nothing = return Nothing - firstMatch :: Ty -> [Resolution] -> SM (Maybe (TcFunDef, Ty, TVSubst)) - firstMatch etyp [] = return Nothing - firstMatch etyp ((t,e):rest) - | Right subst <- specmgu t etyp = do -- TESTME: match is to weak for MPTC, but isn't mgu too strong? 
- debug ["< lookupRes - match found for ", str name, ": ", str t, " ~ ", str etyp, " => ", str subst] - return (Just (e, t, subst)) - | otherwise = firstMatch etyp rest - -getSpSubst :: SM TVSubst -getSpSubst = gets spSubst - -putSpSubst :: TVSubst -> SM () -putSpSubst subst = modify $ \s -> s { spSubst = subst } -extSpSubst :: TVSubst -> SM () - -extSpSubst subst = modify $ \s -> s { spSubst = spSubst s <> subst } - -atCurrentSubst :: HasTV a => a -> SM a -atCurrentSubst a = flip applytv a <$> getSpSubst - -addData :: DataTy -> SM () -addData dt = modify (\s -> s { spDataTable = Map.insert (dataName dt) dt (spDataTable s) }) - -spNewName :: SM Name -spNewName = do - s <- get - let (n, ns) = newName (spNS s) - put s { spNS = ns } - pure (addPrefix "_" n) - --- data Name = Name String | QualName Name String -addPrefix :: String -> Name -> Name -addPrefix p (Name s) = Name (p ++ s) -addPrefix p (QualName q s) = QualName q (p ++ s) +-- | Placeholder type variable used for initial resolution lookups +anyTypePlaceholder :: Ty +anyTypePlaceholder = TyVar (TVar (Name "any")) +------------------------------------------------------------------------------- +-- Main Entry Point ------------------------------------------------------------------------------- specialiseCompUnit :: CompUnit Id -> Bool -> TcEnv -> IO (CompUnit Id) @@ -196,52 +74,107 @@ specialiseCompUnit compUnit debugp env = runSM debugp env do topDecls <- concat <$> forM (contracts compUnit) specialiseTopDecl return $ compUnit { contracts = topDecls } +------------------------------------------------------------------------------- +-- Resolution Table Building +------------------------------------------------------------------------------- +-- Build a table mapping (name, type) to function definitions +-- This enables looking up which definition to use when specializing a call + +-- | Type class for declarations that can contribute to the resolution table +class HasResolutions decl where + addResolutions :: 
decl -> SM () + addGlobalResolutions :: CompUnit Id -> SM () -addGlobalResolutions compUnit = forM_ (contracts compUnit) addDeclResolutions +addGlobalResolutions compUnit = forM_ (contracts compUnit) addResolutions -addDeclResolutions :: TopDecl Id -> SM () -addDeclResolutions (TInstDef inst) = addInstResolutions inst -addDeclResolutions (TFunDef fd) = addFunDefResolution fd -addDeclResolutions (TDataDef dt) = addData dt -addDeclResolutions (TMutualDef decls) = forM_ decls addDeclResolutions -addDeclResolutions _ = return () +instance HasResolutions (TopDecl Id) where + addResolutions (TInstDef inst) = addInstResolutions inst + addResolutions (TFunDef fd) = addFunDefResolution fd + addResolutions (TDataDef dt) = addData dt + addResolutions (TMutualDef decls) = mapM_ addResolutions decls + addResolutions _ = return () +instance HasResolutions (ContractDecl Id) where + addResolutions (CFunDecl fd) = addFunDefResolution fd + addResolutions (CDataDecl dt) = addData dt + addResolutions (CMutualDecl decls) = mapM_ addResolutions decls + addResolutions _ = return () addInstResolutions :: Instance Id -> SM () addInstResolutions inst = forM_ (instFunctions inst) (addMethodResolution (instName inst) (mainTy inst)) +addFunDefResolution :: FunDef Id -> SM () +addFunDefResolution fd = do + let sig = funSignature fd + let name = sigName sig + let funType = typeOfTcFunDef fd + addResolution name funType fd + debug ["+ addDeclResolution: ", show name, " : ", pretty funType] + +addMethodResolution :: Name -> Ty -> TcFunDef -> SM () +addMethodResolution cname ty fd = do + let sig = funSignature fd + let name = sigName sig + let qname = case name of + QualName{} -> name + Name s -> QualName cname s + let name' = specName qname [ty] + let funType = typeOfTcFunDef fd + let fd' = FunDef sig{sigName = name'} (funDefBody fd) + addResolution qname funType fd' + debug ["+ addMethodResolution: ", show qname, " / ", show name', " : ", pretty funType] + +addContractResolutions :: Contract Id 
-> SM () +addContractResolutions (Contract _name _args cdecls) = + forM_ cdecls addResolutions + +------------------------------------------------------------------------------- +-- Top-Level Specialization +------------------------------------------------------------------------------- + specialiseTopDecl :: TopDecl Id -> SM [TopDecl Id] -specialiseTopDecl (TContr (Contract name args decls)) = withLocalState do - addContractResolutions (Contract name args decls) - -- Runtime code - runtimeDecls <- withLocalState do - forM_ entries specEntry - getSpecialisedDecls - -- Deployer code - modify (\st -> st { specTable = emptyTable }) - deployDecls <- case findConstructor decls of - Just c -> withLocalState do - cname' <- specConstructor c - st <- gets specTable - depDecls <- getSpecialisedDecls - -- use mutual to group constructor with its dependencies - pure [CMutualDecl depDecls] - Nothing -> pure [] +specialiseTopDecl (TContr contract@(Contract name args decls)) = withLocalState $ do + addContractResolutions contract + runtimeDecls <- specialiseRuntimeDecls + deployDecls <- specialiseDeployerCode decls return [TContr (Contract name args (deployDecls ++ runtimeDecls))] - where - entries = ["main"] -- Eventually all public methods - getSpecialisedDecls :: SM [ContractDecl Id] - getSpecialisedDecls = do + where + -- | Specializes all runtime entry points (public methods) + specialiseRuntimeDecls :: SM [ContractDecl Id] + specialiseRuntimeDecls = withLocalState $ do + forM_ [entryPointName] specEntry -- Eventually all public methods + collectSpecialisedDecls + + -- | Specializes the constructor (deployer code) if present + specialiseDeployerCode :: [ContractDecl Id] -> SM [ContractDecl Id] + specialiseDeployerCode contractDecls = do + -- Clear the specialization table for deployer code + modify (\st -> st { specTable = emptyTable }) + case findConstructor contractDecls of + Just c -> specialiseConstructorDecls c + Nothing -> pure [] + + -- | Specializes a constructor and 
its dependencies + specialiseConstructorDecls :: Constructor Id -> SM [ContractDecl Id] + specialiseConstructorDecls c = withLocalState $ do + _cname' <- specConstructor c + depDecls <- collectSpecialisedDecls + -- Use mutual to group constructor with its dependencies + pure [CMutualDecl depDecls] + + -- | Collects all specialized functions and data types from the current state + collectSpecialisedDecls :: SM [ContractDecl Id] + collectSpecialisedDecls = do st <- gets specTable dt <- gets spDataTable let dataDecls = map (CDataDecl . snd) (Map.toList dt) let funDecls = map (CFunDecl . snd) (Map.toList st) pure (dataDecls ++ funDecls) --- keep datatype defs intact +-- Keep datatype defs intact specialiseTopDecl d@TDataDef{} = pure [d] --- Drop all toplevel decls that are not contracts - we do not need them anymore -specialiseTopDecl decl = pure [] +-- Drop all toplevel decls that are not contracts - we don't need them after specialization +specialiseTopDecl _ = pure [] findConstructor :: [ContractDecl Id] -> Maybe (Constructor Id) findConstructor = foldr (\d -> (getConstructor d <|>)) Nothing @@ -253,9 +186,7 @@ getConstructor _ = Nothing specEntry :: Name -> SM () specEntry name = withLocalState do - let any = TVar (Name "any") - let anytype = TyVar any - mres <- lookupResolution name anytype + mres <- lookupResolution name anyTypePlaceholder case mres of Just (fd, ty, subst) -> do debug ["< resolution: ", show name, " : ", pretty ty, "@", pretty subst] @@ -265,68 +196,49 @@ specEntry name = withLocalState do specConstructor :: Constructor Id -> SM Name specConstructor (Constructor [] body) = do - let sig = Signature [] [] (Name "constructor") [] (Just unit) + let sig = Signature [] [] constructorName [] (Just unit) let fd = FunDef sig body specFunDef fd -specConstructor (Constructor params body) = error "Unsupported constructor" - -addContractResolutions :: Contract Id -> SM () -addContractResolutions (Contract name args decls) = do - forM_ decls addCDeclResolution 
- -addCDeclResolution :: ContractDecl Id -> SM () -addCDeclResolution (CFunDecl fd) = addFunDefResolution fd -addCDeclResolution (CDataDecl dt) = addData dt -addCDeclResolution (CMutualDecl decls) = forM_ decls addCDeclResolution -addCDeclResolution _ = return () - -addFunDefResolution fd = do - let sig = funSignature fd - let name = sigName sig - let funType = typeOfTcFunDef fd - addResolution name funType fd - debug ["+ addDeclResolution: ", show name, " : ", pretty funType] +specConstructor _ = error "Unsupported constructor" -addMethodResolution :: Name -> Ty -> TcFunDef -> SM () -addMethodResolution cname ty fd = do - let sig = funSignature fd - let name = sigName sig - let qname = case name of - QualName{} -> name - Name s -> QualName cname s - let name' = specName qname [ty] - let funType = typeOfTcFunDef fd - let fd' = FunDef sig{sigName = name'} (funDefBody fd) - addResolution qname funType fd' - debug ["+ addMethodResolution: ", show qname, " / ", show name', " : ", pretty funType] +------------------------------------------------------------------------------- +-- Expression Specialization +------------------------------------------------------------------------------- -- | `specExp` specialises an expression to given type specExp :: TcExp -> Ty -> SM TcExp -specExp e@(Call Nothing i args) ty = do +specExp (Call Nothing i args) ty = do -- debug ["> specExp (Call): ", pretty e, " : ", pretty (idType i), " ~> ", pretty ty] (i', args') <- specCall i args ty let e' = Call Nothing i' args' -- debug ["< specExp (Call): ", pretty e'] return e' -specExp e@(Con i@(Id n conty) es) ty = do - let t = typeOfTcExp e +specExp (Con i es) ty = do + -- let t = typeOfTcExp e -- debug ["> specConApp: ", pretty e, " : ", pretty t, " ~> ", pretty ty] (i' , es') <- specConApp i es ty let e' = Con i' es' return e' -specExp e@(Cond e1 e2 e3) ty = do +specExp (Cond e1 e2 e3) ty = do e1' <- specExp e1 desugaredBoolTy e2' <- specExp e2 ty e3' <- specExp e3 ty pure (Cond e1' e2' 
e3') -specExp e@(Var (Id n t)) ty = pure (Var (Id n ty)) -specExp e@(FieldAccess me fld) ty = error("Specialise: FieldAccess not implemented for" ++ pretty e) -specExp e@(TyExp e1 _) ty = specExp e1 ty -specExp e ty = atCurrentSubst e -- FIXME +specExp (Var (Id n _t)) ty = pure (Var (Id n ty)) +specExp e@FieldAccess{} _ty = + error $ unlines + [ "Specialization failed: FieldAccess should have been desugared" + , "Expression: " ++ pretty e + , "" + , "This is a compiler bug - FieldAccess expressions should not reach specialization." + , "Check that field access desugaring ran before specialization in the pipeline." + ] +specExp (TyExp e1 _) ty = specExp e1 ty +specExp e _ty = atCurrentSubst e -- FIXME specConApp :: Id -> [TcExp] -> Ty -> SM (Id, [TcExp]) -- specConApp i@(Id n conTy) [] ty = pure (i, []) -specConApp i@(Id n conTy) args ty = do +specConApp i@(Id _n conTy) args ty = do subst <- getSpSubst let argTypes = map typeOfTcExp args let argTypes' = applytv subst argTypes @@ -338,104 +250,166 @@ specConApp i@(Id n conTy) args ty = do debug ["< specConApp: ", prettyConApp i args, " ~> ", prettyConApp i' args'] return (i', args') --- | Specialise a function call --- given actual arguments and the expected result type +-- | Specializes a function call to concrete types +-- Given actual arguments and the expected result type specCall :: Id -> [TcExp] -> Ty -> SM (Id, [TcExp]) -specCall i@(Id (Name "revert") e) args ty = pure (i, args) -- FIXME -specCall i args ty = do - i' <- atCurrentSubst i - ty' <- atCurrentSubst ty - -- debug ["> specCall: ", pretty i', show args, " : ", pretty ty'] - let name = idName i' - let argTypes = map typeOfTcExp args - argTypes' <- atCurrentSubst argTypes - let typedArgs = zip args argTypes' - args' <- forM typedArgs (uncurry specExp) - let funType = foldr (:->) ty' argTypes' - debug ["> specCall: ", show name, " : ", pretty funType] - mres <- lookupResolution name funType - case mres of - Just (fd, ty, phi) -> do - debug ["< 
resolution: ", show name, "~>", shortName fd, " : ", pretty ty, "@", pretty phi] - extSpSubst phi - -- ty' <- atCurrentSubst ty - subst <- getSpSubst - let ty' = applytv subst ty - ensureClosed ty' (Call Nothing i args) subst - name' <- specFunDef fd - debug ["< specCall: ", pretty name'] - args'' <- atCurrentSubst args' - return (Id name' ty', args'') - Nothing -> do - panics ["! specCall: no resolution found for ", show name, " : ", pretty funType] - return (i, args') +specCall i@(Id name _e) args ty + | isBuiltin name = pure (i, args) -- Built-ins don't need specialization + | otherwise = do + -- Apply current substitution + i' <- atCurrentSubst i + expectedResultTy <- atCurrentSubst ty + + -- Specialize arguments + let argTypes = map typeOfTcExp args + argTypes' <- atCurrentSubst argTypes + specializedArgs <- zipWithM specExp args argTypes' + + -- Compute expected function type + let funType = foldr (:->) expectedResultTy argTypes' + debug ["> specCall: ", show (idName i'), " : ", pretty funType] + + -- Look up and apply resolution + resolveAndSpecialize (idName i') funType specializedArgs i args + where + _ty = _ty -- Suppress unused variable warning + +-- | Check if a name is a built-in primitive that doesn't need specialization +isBuiltin :: Name -> Bool +isBuiltin name = name == revertBuiltin + +-- | Resolves a function name to its definition and specializes it +resolveAndSpecialize :: Name -> Ty -> [TcExp] -> Id -> [TcExp] -> SM (Id, [TcExp]) +resolveAndSpecialize name funType specializedArgs originalId originalArgs = do + mres <- lookupResolution name funType + case mres of + Just (fd, fty, substitution) -> applyResolution fd fty substitution + Nothing -> handleMissingResolution name funType originalId specializedArgs where - guardSimpleType :: Ty -> SM () - guardSimpleType (Meta _) = panics ["specCall ", pretty i, ": polymorphic result type"] - guardSimpleType (TyVar _) = panics ["specCall ", pretty i, ": polymorphic result type"] - guardSimpleType (_ 
:-> _) = panics ["specCall ", pretty i, ": function result type"] - guardSimpleType _ = pure () - --- | `specFunDef` specialises a function definition --- to the given type of the form `arg1Ty -> arg2Ty -> ... -> resultTy` --- first lookup if a specialisation to the given type exists --- if not, look for a resolution (definition matching the expected type) --- create a new specialisation of it and record it in `specTable` --- returns name of the specialised function + -- Apply a found resolution to specialize the function + applyResolution :: TcFunDef -> Ty -> TVSubst -> SM (Id, [TcExp]) + applyResolution fd fty phi = do + debug ["< resolution: ", show name, "~>", shortName fd, " : ", pretty funType, "@", pretty phi] + extSpSubst phi + subst <- getSpSubst + let instantiatedType = applytv subst fty + ensureClosed instantiatedType (Call Nothing originalId originalArgs) subst + specializedName <- specFunDef fd + debug ["< specCall: ", pretty specializedName] + finalArgs <- atCurrentSubst specializedArgs + return (Id specializedName instantiatedType, finalArgs) + + -- Handle the case where no resolution is found + handleMissingResolution :: Name -> Ty -> Id -> [TcExp] -> SM (Id, [TcExp]) + handleMissingResolution n ft i args = do + void $ panics + [ "Specialization failed: No resolution found for function call" + , "\nFunction: " ++ pretty n + , "\nRequired type: " ++ pretty ft + , "\n" + , "\nPossible causes:" + , "\n - Function is not defined or not in scope" + , "\n - Type mismatch between call site and definition" + , "\n - Missing instance for type class constraint" + , "\n" + , "\nCheck that the function is defined and that type inference is correct." + ] + return (i, args) + +------------------------------------------------------------------------------- +-- Function Specialization +------------------------------------------------------------------------------- + +-- | Specializes a function definition to concrete types +-- +-- Algorithm: +-- 1. 
Rename bound type variables to avoid capture +-- 2. Apply current substitution to eliminate type variables +-- 3. Generate mangled name based on concrete types (e.g., map$word) +-- 4. Check if specialization already exists (memoization) +-- 5. If not, create placeholder (breaks recursive loops), specialize body, record result +-- +-- Returns the mangled name of the specialized function specFunDef :: TcFunDef -> SM Name -specFunDef fd0 = withLocalState do - -- first, rename bound variables - (fd, renamingSubst) <- renametv fd0 - let renaming = fromTVS renamingSubst - let sig0 = funSignature fd - let sig = funSignature fd - let name = sigName sig - let funType = typeOfTcFunDef fd - let tvs = freetv funType - subst <- renameSubst renaming <$> getSpSubst - putSpSubst subst - let tvs' = applytv subst (map TyVar tvs) - debug ["> specFunDef ", pretty name, " : ", pretty funType, " tvs'=", prettys tvs', " subst=", pretty subst] - let name' = specName name tvs' - let ty' = applytv subst funType - mspec <- lookupSpecialisation name' - case mspec of - Just fd' -> return name' - Nothing -> do - let sig' = applytv subst (funSignature fd) - -- add a placeholder first to break loops - let placeholder = FunDef sig' [] - addSpecialisation name' placeholder - body' <- specBody (funDefBody fd) - let fd' = FunDef sig'{sigName = name'} body' - debug ["+ specFunDef: adding specialisation ", show name', " : ", pretty ty'] - addSpecialisation name' fd' - return name' +specFunDef originalDef = withLocalState $ do + -- Step 1: Rename bound type variables to avoid capture + (renamedDef, renaming) <- renametv originalDef + let originalSig = funSignature originalDef + let renamedSig = funSignature renamedDef + + -- Step 2: Apply renaming to current substitution + subst <- renameSubst renaming <$> getSpSubst + putSpSubst subst + debug ["> specFunDef raw input: ", pretty originalSig, + " renaming=", pretty renaming, " subst=", pretty subst] + + -- Step 3: Eliminate type variables and type class 
context + let monomorphicSig = applytv subst renamedSig { sigVars = [], sigContext = [] } + let name = sigName renamedSig + let funType = typeOfTcFunDef renamedDef + + -- Step 4: Compute concrete type instantiation + let freeTypeVars = freetv funType + let concreteTypes = applytv subst (map TyVar freeTypeVars) + debug ["> specFunDef ", pretty name, " : ", pretty renamedSig, + " concreteTypes=", prettys concreteTypes, " subst=", pretty subst] + + -- Step 5: Generate mangled name (e.g., map$word$bool) + let specializedName = specName name concreteTypes + let specializedType = applytv subst funType + + -- Step 6: Check memoization table + memoizedSpec <- lookupSpecialisation specializedName + case memoizedSpec of + Just _ -> return specializedName -- Already specialized + Nothing -> createNewSpecialization specializedName monomorphicSig renamedDef specializedType + where + -- Create and record a new specialization + createNewSpecialization :: Name -> Signature Id -> TcFunDef -> Ty -> SM Name + createNewSpecialization name sig def ty = do + -- Add placeholder first to break infinite recursion in mutually recursive functions + let placeholder = FunDef sig { sigName = name } [] + addSpecialisation name placeholder + + -- Recursively specialize the function body + specializedBody <- specBody (funDefBody def) + let specializedDef = FunDef sig { sigName = name } specializedBody + + debug ["+ specFunDef: adding specialisation ", show name, " : ", pretty ty] + addSpecialisation name specializedDef + return name + +------------------------------------------------------------------------------- +-- Statement Specialization +------------------------------------------------------------------------------- specBody :: [Stmt Id] -> SM [Stmt Id] specBody = mapM specStmt -{- -ensureSimple ty' stmt subst = case ty' of - TyVar _ -> panics [ "specStmt(",pretty stmt,"): polymorphic return type: " - , pretty ty', " subst=", pretty subst] - _ :-> _ -> panics [ "specStmt(",pretty stmt,"): 
function return type: " - , pretty ty' - ,"\nIn:\n", show stmt - ] - _ -> return () --} - --- | `ensureClosed` checks that a type is closed, i.e. has no free type variables +-- | Ensures that a type is closed (has no free type variables) +-- All types must be fully concrete after specialization for Yul code generation ensureClosed :: Pretty a => Ty -> a -> TVSubst -> SM () ensureClosed ty ctxt subst = do let tvs = freetv ty - unless (null tvs) $ panics ["spec(", pretty ctxt,"): free type vars in ", pretty ty, ": ", show tvs - , " @ subst=", pretty subst] + unless (null tvs) $ panics + [ "Specialization failed: Type still contains free type variables" + , "\nContext: " ++ pretty ctxt + , "\nType: " ++ pretty ty + , "\nFree type variables: " ++ show tvs + , "\nCurrent substitution: " ++ pretty subst + , "\n" + , "\nThis indicates incomplete specialization - all types must be concrete." + , "\nCheck that all polymorphic functions have been properly instantiated." + ] {- let mvs = mv ty - unless (null tvs) $ panics ["spec(", pretty ctxt,"): free meta vars in ", pretty ty, ": ", show mvs - , " @ subst=", pretty subst] + unless (null mvs) $ panics + [ "Specialization failed: Type still contains meta variables" + , "\nContext: " ++ pretty ctxt + , "\nType: " ++ pretty ty + , "\nMeta variables: " ++ show mvs + , "\nCurrent substitution: " ++ pretty subst + ] -} specStmt :: Stmt Id -> SM(Stmt Id) @@ -479,57 +453,93 @@ specStmt (StmtExp e) = do return $ StmtExp e' specStmt (Asm ys) = pure (Asm ys) -specStmt stmt = errors ["specStmt not implemented for: ", show stmt] +specStmt stmt = errors + [ "Specialization failed: Unsupported statement type" + , "Statement: " ++ show stmt + , "" + , "This statement type has not been implemented in the specializer." 
+ ] specMatch :: [Exp Id] -> [([Pat Id], [Stmt Id])] -> SM (Stmt Id) specMatch exps alts = do - subst <- getSpSubst - -- debug ["> specMatch, scrutinee: ", pretty exps, " @ ", pretty subst] exps' <- specScruts exps alts' <- forM alts specAlt - -- debug ["< specMatch, alts': ", show alts'] return $ Match exps' alts' where specAlt (pat, body) = do - -- debug ["specAlt, pattern: ", show pat] - -- debug ["specAlt, body: ", show body] body' <- specBody body pat' <- atCurrentSubst pat return (pat', body') specScruts = mapM specScrut specScrut e = do - subst <- getSpSubst ty <- atCurrentSubst (typeOfTcExp e) e' <- specExp e ty - -- debug ["specScrut: ", show e, " to ", pretty ty, " ~>", show e'] return e' - +------------------------------------------------------------------------------- +-- Name Mangling +------------------------------------------------------------------------------- +-- Generate unique names for specialized functions by encoding type arguments +-- in the function name. This allows multiple specialized versions to coexist. + +-- | Generate a unique mangled name for a specialized function +-- +-- Examples: +-- @map + [] → "map"@ (no type arguments, no mangling) +-- @map + [word] → "map$word"@ +-- @foo + [word, bool] → "foo$word_bool"@ +-- @Bar.baz + [unit] → "Bar_baz$unit"@ +-- @Pair + [word, bool] → "Pair$word_bool"@ specName :: Name -> [Ty] -> Name specName n [] = Name $ flattenQual n specName n ts = Name $ flattenQual n ++ "$" ++ intercalate "_" (map mangleTy ts) +-- | Flatten a qualified name into a string, replacing dots with underscores +-- +-- Examples: +-- @Name "foo" → "foo"@ +-- @QualName (Name "Bar") "baz" → "Bar_baz"@ +-- @QualName (QualName (Name "A") "B") "C" → "A_B_C"@ flattenQual :: Name -> String flattenQual (Name n) = n flattenQual (QualName n s) = flattenQual n ++ "_" ++ s +-- | Mangle a type into a string suitable for name mangling +-- +-- Type variables are omitted (should not appear after substitution). 
+-- Capitalization is preserved from the original type names. +-- +-- Examples: +-- @word → "word"@ +-- @bool → "bool"@ +-- @() → "unit"@ +-- @List<word> → "ListLwordJ"@ +-- @Pair<word, bool> → "PairLword_boolJ"@ +-- @Map<word, List<bool>> → "MapLword_ListLboolJJ"@ +-- +-- The delimiters L and J mark the beginning and end of type argument lists, +-- allowing nested generic types to be unambiguously encoded. mangleTy :: Ty -> String -mangleTy (TyVar (TVar (Name n))) = n -mangleTy (Meta (MetaTv (Name n))) = n +mangleTy (TyVar _) = "" -- Type vars should be eliminated by substitution mangleTy (TyCon (Name "()") []) = "unit" mangleTy (TyCon (Name n) []) = n -mangleTy (TyCon (Name n) ts) = n ++ "L" ++ intercalate "_" (map mangleTy ts) ++"J" +mangleTy (TyCon n ts) = flattenQual n ++ embrace mantys + where + mantys = filter (not . null) (map mangleTy ts) + embrace [] = "" + embrace xs = "L" ++ intercalate "_" xs ++ "J" -- L...J delimiters +mangleTy t = error $ "Specialise: mangleTy not implemented for " ++ show t -showId :: Id -> String -showId i = showsId i "" -showsId (Id n t) = shows n . ('@':) . showsPrec 10 t +------------------------------------------------------------------------------- +-- Pretty Printing Utilities +------------------------------------------------------------------------------- prettyId :: Id -> String prettyId = render . 
pprId pprId :: Id -> Doc pprId (Id n t@TyVar{}) = ppr n <> text "@" <> ppr t -pprId (Id n t@(TyCon cn [])) = ppr n <> "@" <> ppr t +pprId (Id n t@TyCon{}) = ppr n <> "@" <> ppr t pprId (Id n t) = ppr n <> text "@" <> parens(ppr t) pprConApp :: Id -> [TcExp] -> Doc @@ -537,238 +547,3 @@ pprConApp i args = pprId i <> brackets (commaSepList args) prettyConApp :: Id -> [TcExp] -> String prettyConApp i args = render (pprConApp i args) - - -typeOfTcExp :: TcExp -> Ty -typeOfTcExp (Var i) = idType i -typeOfTcExp (Con i []) = idType i -typeOfTcExp e@(Con i args) = go (idType i) args where - go ty [] = ty - go (_ :-> u) (a:as) = go u as - go _ _ = error $ "typeOfTcExp: " ++ show e -typeOfTcExp (Lit (IntLit _)) = word --TyCon "Word" [] -typeOfTcExp exp@(Call Nothing i args) = applyTo args funTy where - funTy = idType i - applyTo [] ty = ty - applyTo (_:as) (_ :-> u) = applyTo as u - applyTo _ _ = error $ concat [ "apply ", pretty i, " : ", pretty funTy - , "to", show $ map pretty args - , "\nIn:\n", show exp - ] -typeOfTcExp (Lam args body (Just tb)) = funtype tas tb where - tas = map typeOfTcParam args -typeOfTcExp (Cond _ _ e) = typeOfTcExp e -typeOfTcExp (TyExp _ ty) = ty -typeOfTcExp e = error $ "typeOfTcExp: " ++ show e - -typeOfTcStmt :: Stmt Id -> Ty -typeOfTcStmt (n := e) = unit -typeOfTcStmt (Let n _ _) = idType n -typeOfTcStmt (StmtExp e) = typeOfTcExp e -typeOfTcStmt (Return e) = typeOfTcExp e -typeOfTcStmt (Match _ ((pat, body):_)) = typeOfTcBody body - -typeOfTcBody :: [Stmt Id] -> Ty -typeOfTcBody [] = unit -typeOfTcBody [s] = typeOfTcStmt s -typeOfTcBody (_:b) = typeOfTcBody b - -typeOfTcParam :: Param Id -> Ty -typeOfTcParam (Typed i t) = idType i -- seems better than t - see issue #6 -typeOfTcParam (Untyped i) = idType i - -typeOfTcSignature :: Signature Id -> Ty -typeOfTcSignature sig = funtype (map typeOfTcParam $ sigParams sig) (returnType sig) where - returnType sig = case sigReturn sig of - Just t -> t - Nothing -> error ("no return type in signature 
of: " ++ show (sigName sig)) - -schemeOfTcSignature :: Signature Id -> Scheme -schemeOfTcSignature sig@(Signature vs ps n args (Just rt)) - = if all isTyped args - then Forall vs (ps :=> (funtype ts rt)) - else error $ unwords ["Invalid instance member signature:", pretty sig] - where - isTyped (Typed _ _) = True - isTyped _ = False - ts = map (\ (Typed _ t) -> t) args - -typeOfTcFunDef :: TcFunDef -> Ty -typeOfTcFunDef (FunDef sig _) = typeOfTcSignature sig - -pprRes :: Resolution -> Doc --- type Resolution = (Ty, FunDef Id) -pprRes(ty, fd) = ppr ty <+> text ":" <+> text(shortName fd) - -instance Pretty (Ty, FunDef Id) where - ppr = pprRes - -specmgu :: Ty -> Ty -> Either String TVSubst -specmgu (TyCon n ts) (TyCon n' ts') - | n == n' && length ts == length ts' = - specsolve (zip ts ts') mempty -specmgu (TyVar v) t = varBind v t -specmgu t (TyVar v) = varBind v t -specmgu t1 t2 = typesDoNotUnify t1 t2 - -varBind :: (MonadError String m) => Tyvar -> Ty -> m TVSubst -varBind v t - | t == TyVar v = return mempty - | v `elem` freetv t = infiniteTyErr v t - | otherwise = do - return (v |-> t) - where - infiniteTyErr v t = throwError $ - unwords - [ "Cannot construct the infinite type:" - , pretty v - , "~" - , pretty t - ] - -specsolve :: [(Ty, Ty)] -> TVSubst -> Either String TVSubst -specsolve [] s = pure s -specsolve ((t1, t2) : ts) s = - do - s1 <- specmgu (applytv s t1) (applytv s t2) - s2 <- specsolve ts s1 - pure (s2 <> s1) - -newtype TVSubst - = TVSubst { unTVSubst :: [(Tyvar, Ty)] } deriving (Eq, Show) - -restrict :: TVSubst -> [Tyvar] -> TVSubst -restrict (TVSubst s) vs - = TVSubst [(v,t) | (v,t) <- s, v `notElem` vs] - -emptyTVSubst :: TVSubst -emptyTVSubst = TVSubst [] - --- composition operators - -instance Semigroup TVSubst where - s1 <> s2 = TVSubst (outer ++ inner) - where - outer = [(u, applytv s1 t) | (u, t) <- unTVSubst s2] - inner = [(v,t) | (v,t) <- unTVSubst s1, v `notElem` dom2] - dom2 = map fst (unTVSubst s2) - -instance Monoid TVSubst where - 
mempty = emptyTVSubst - -(|->) :: Tyvar -> Ty -> TVSubst -u |-> t = TVSubst [(u, t)] - -instance Pretty TVSubst where - ppr = braces . commaSep . map go . unTVSubst - where - go (v,t) = ppr v <+> text "|->" <+> ppr t - -class Data a => HasTV a where - applytv :: TVSubst -> a -> a - applytv s = everywhere (mkT (applytv @Ty s)) - - freetv :: a -> [Tyvar] -- free variables - freetv = everything (<>) (mkQ mempty (freetv @Ty)) - - renametv :: a -> SM (a, TVSubst) - renametv a = pure (a, mempty) - -instance HasTV Ty where - applytv (TVSubst s) t@(TyVar v) - = maybe t id (lookup v s) - applytv s (TyCon n ts) - = TyCon n (applytv s ts) - applytv _ t = t - - freetv (TyVar v@(TVar _)) = [v] - freetv (TyCon _ ts) = freetv ts - freetv _ = [] - -instance HasTV a => HasTV [a] where - applytv s = map (applytv s) - freetv = foldr (union . freetv) mempty - -instance HasTV a => HasTV (Maybe a) where - applytv s = fmap (applytv s) - freetv = maybe [] freetv - -instance (HasTV a, HasTV b) => HasTV (a,b) where -- defaults - -{- -instance (HasTV a, HasTV b, HasTV c) => HasTV (a,b,c) where - applytv s (z,x,y) = (applytv s z, applytv s x, applytv s y) - freetv (z,x,y) = freetv z `union` freetv x `union` freetv y - -instance (HasTV a, HasTV b) => HasTV (a,b) where - applytv s (x,y) = (applytv s x, applytv s y) - freetv (x,y) = freetv x `union` freetv y --} - -instance HasTV Id where - applytv s (Id n t) = Id n (applytv s t) - freetv (Id _ t) = freetv t - -instance HasTV a => HasTV (Param a) where -- defaults -instance HasTV a => HasTV (Exp a) where -- defaults -instance HasTV a => HasTV (Stmt a) where -- defaults - -instance HasTV (Pat Id) where - - -instance HasTV (Signature Id) where - applytv s = everywhere (mkT (applytv @Ty s)) - freetv sig = (everything (<>) (mkQ mempty (freetv @Ty))) sig \\ sigVars sig - renametv sig = do - subst <- foldM addRenaming mempty (sigVars sig) - pure (applytv subst sig, subst) - -{- -data FunDef a - = FunDef { - funSignature :: Signature a - , funDefBody 
:: [Stmt a] - } deriving (Eq, Ord, Show, Data, Typeable) --} - -instance HasTV (FunDef Id) where - freetv fd = (everything (<>) (mkQ mempty (freetv @Ty))) fd \\ sigVars (funSignature fd) - renametv fd = do - let sig = funSignature fd - subst <- foldM addRenaming mempty (sigVars sig) - let sig' = applytv subst sig - let body' = applytv subst (funDefBody fd) - pure(FunDef sig' body', subst) - -addRenaming :: TVSubst -> Tyvar -> SM TVSubst -addRenaming b a = do - fresh <- spNewName - pure ( (a |-> TyVar (TVar fresh)) <> b ) - --- TODO: refactor - make renametv return TVRenaming; turn rename* into class methods - -newtype TVRenaming - = TVR { unTVR :: [(Tyvar, Tyvar)] } deriving (Eq, Show) - -instance Pretty TVRenaming where - ppr = braces . commaSep . map go . unTVR - where - go (v,t) = ppr v <+> text "|->" <+> ppr t - -toTVS :: TVRenaming -> TVSubst -toTVS = TVSubst . map (fmap TyVar) . unTVR - -fromTVS :: TVSubst -> TVRenaming -fromTVS = TVR . map (fmap unTyVar) . unTVSubst where - unTyVar (TyVar x) = x - unTyVar t = error("fromTVS: " ++ pretty t ++ "is not a type variable") - -renameTV :: TVRenaming -> Tyvar -> Tyvar -renameTV (TVR r) v = fromMaybe v (lookup v r) - -renameTy :: TVRenaming -> Ty -> Ty -renameTy r = everywhere (mkT (renameTV r)) - -renameSubst :: TVRenaming -> TVSubst -> TVSubst -renameSubst r = TVSubst . map rename . unTVSubst where - rename (v, t) = (renameTV r v, renameTy r t) diff --git a/src/Solcore/Frontend/Pretty/ShortName.hs b/src/Solcore/Frontend/Pretty/ShortName.hs index d78aba78..9422c2ac 100644 --- a/src/Solcore/Frontend/Pretty/ShortName.hs +++ b/src/Solcore/Frontend/Pretty/ShortName.hs @@ -8,8 +8,6 @@ import Solcore.Frontend.TypeInference.Id import Solcore.Frontend.Pretty.SolcorePretty(pretty) import Common.Pretty -prettys :: Pretty a => [a] -> String -prettys = render . brackets . commaSep . 
map ppr class Pretty a => HasShortName a where shortName :: a -> String diff --git a/test/examples/spec/051negBool.solc b/test/examples/spec/051negBool.solc index 26319342..22888353 100644 --- a/test/examples/spec/051negBool.solc +++ b/test/examples/spec/051negBool.solc @@ -1,4 +1,4 @@ - +forall a. class a : Neg { function neg(x:a) -> a; } @@ -7,7 +7,7 @@ data B = F | T; instance B : Neg { - function neg (x : B) { + function neg (x : B) -> B { match x { | F => return T; | T => return F; @@ -18,12 +18,12 @@ instance B : Neg { contract NegBool { - function fromB(b) { + function fromB(b:B) -> word { match b { | F => return 0; | T => return 1; } } - function main() { return fromB(Neg.neg(F)); } + function main() -> word { return fromB(Neg.neg(F)); } } diff --git a/test/examples/spec/052negPair.solc b/test/examples/spec/052negPair.solc index 3d3542d0..194bc9e6 100644 --- a/test/examples/spec/052negPair.solc +++ b/test/examples/spec/052negPair.solc @@ -1,4 +1,4 @@ - +forall a. class a : Neg { function neg(x:a) -> a; } @@ -7,7 +7,7 @@ data B = F | T; data Pair(a,b) = Pair(a,b); instance B : Neg { - function neg (x : B) { + function neg (x : B) -> B { match x { | F => return T; | T => return F; @@ -28,8 +28,9 @@ function snd(p) { } -instance (a:Neg,b:Neg) => Pair(a,b):Neg { - function neg(p) { +forall a b. +a:Neg,b:Neg => instance Pair(a,b):Neg { + function neg(p:Pair(a,b)) -> Pair(a,b) { return Pair(Neg.neg (fst(p)), Neg.neg(snd (p))); } } @@ -45,19 +46,19 @@ instance (a:Neg,b:Neg) => Pair(a,b):Neg { */ contract NegPair { - function bnot(x) { + function bnot(x:B) -> B { match x { | T => return F; | F => return T; } } - function fromB(b) { + function fromB(b:B) -> word { match b { | F => return 0; | T => return 1; } } - function main() { return fromB(fst(Neg.neg(Pair(F,T)))); } + function main() -> word { return fromB(fst(Neg.neg(Pair(F,T)))); } }