diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 85f4afd..450072e 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -29,7 +29,10 @@ library Language.Granule.Codegen.Compile Language.Granule.Codegen.MarkGlobals other-modules: - Language.Granule.Codegen.Builtins + Language.Granule.Codegen.Builtins.Builtins + Language.Granule.Codegen.Builtins.Extras + Language.Granule.Codegen.Builtins.FloatArray + Language.Granule.Codegen.Builtins.Shared Language.Granule.Codegen.Emit.EmitableDef Language.Granule.Codegen.Emit.EmitBuiltins Language.Granule.Codegen.Emit.EmitterState @@ -44,6 +47,7 @@ library Language.Granule.Codegen.Emit.Names Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types + Language.Granule.Codegen.RewriteAST Paths_granule_compiler hs-source-dirs: src diff --git a/src/Language/Granule/Codegen/Builtins.hs b/src/Language/Granule/Codegen/Builtins.hs deleted file mode 100644 index c48ea5a..0000000 --- a/src/Language/Granule/Codegen/Builtins.hs +++ /dev/null @@ -1,63 +0,0 @@ -{-# LANGUAGE RankNTypes #-} -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module Language.Granule.Codegen.Builtins where - -import LLVM.AST (Operand) -import LLVM.AST.Type as IR -import LLVM.IRBuilder (MonadIRBuilder, sdiv) -import LLVM.IRBuilder.Instruction (zext) -import LLVM.IRBuilder.Module (MonadModuleBuilder) -import Language.Granule.Syntax.Identifiers -import Language.Granule.Syntax.Type as Gr - -data Builtin = Builtin { - builtinId :: String, - builtinArgTys :: [Gr.Type], - builtinRetTy :: Gr.Type, - builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} - -mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type -mkFunType args ret = foldr (FunTy Nothing Nothing) ret args - -builtins :: [Builtin] -builtins = [charToIntDef, divDef] - -builtinIds :: [Id] -builtinIds = map (mkId . builtinId) builtins - --- charToInt :: Char -> Int -charToIntDef :: Builtin -charToIntDef = - Builtin "charToInt" args ret impl - where - args = [TyCon (Id "Char" "Char")] - ret = TyCon (Id "Int" "Int") - impl [x] = zext x i32 - --- div :: Int -> Int -> Int -divDef :: Builtin -divDef = - Builtin "div" args ret impl - where - args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] - ret = TyCon (Id "Int" "Int") - impl [x, y] = sdiv x y - -structTy :: IR.Type -structTy = StructureType False [i32, ptr IR.double] - -tyInt :: Gr.Type -tyInt = TyCon (Id "Int" "Int") - -tyFloat :: Gr.Type -tyFloat = TyCon (Id "Float" "Float") - -tyChar :: Gr.Type -tyChar = TyCon (Id "Char" "Char") - -tyPair :: (Gr.Type, Gr.Type) -> Gr.Type -tyPair (l, r) = TyApp (TyApp (TyCon (Id "," ",")) l) r - -tyFloatArray :: Gr.Type -tyFloatArray = TyApp (TyCon (Id "FloatArray" "FloatArray")) (TyVar (Id "id" "id")) diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs new file mode 100644 index 0000000..f4aba6e --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -0,0 +1,24 @@ +module Language.Granule.Codegen.Builtins.Builtins where + +import Language.Granule.Codegen.Builtins.Extras +import Language.Granule.Codegen.Builtins.FloatArray +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Syntax.Identifiers (Id, mkId) + +builtins :: [Builtin] +builtins = + [ charToIntDef, + divDef, + newFloatArrayIDef, + readFloatArrayIDef, + writeFloatArrayIDef, + lengthFloatArrayIDef, + newFloatArrayDef, + readFloatArrayDef, + writeFloatArrayDef, + lengthFloatArrayDef, + deleteFloatArrayDef + ] + +builtinIds :: [Id] +builtinIds = map (mkId . builtinId) builtins diff --git a/src/Language/Granule/Codegen/Builtins/Extras.hs b/src/Language/Granule/Codegen/Builtins/Extras.hs new file mode 100644 index 0000000..63e96b8 --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Extras.hs @@ -0,0 +1,27 @@ +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.Extras where + +import LLVM.AST.Type as IR +import LLVM.IRBuilder.Instruction +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type as Gr + +-- charToInt :: Char -> Int +charToIntDef :: Builtin +charToIntDef = + Builtin "charToInt" args ret impl + where + args = [TyCon (Id "Char" "Char")] + ret = TyCon (Id "Int" "Int") + impl [x] = zext x i32 + +-- div :: Int -> Int -> Int +divDef :: Builtin +divDef = + Builtin "div" args ret impl + where + args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] + ret = TyCon (Id "Int" "Int") + impl [x, y] = sdiv x y diff --git a/src/Language/Granule/Codegen/Builtins/FloatArray.hs b/src/Language/Granule/Codegen/Builtins/FloatArray.hs new file mode 100644 index 0000000..9a803e4 --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/FloatArray.hs @@ -0,0 +1,133 @@ +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.FloatArray where + +import qualified LLVM.AST.Constant as C +import LLVM.AST.IntegerPredicate +import LLVM.AST.Operand (Operand (ConstantOperand)) +import LLVM.AST.Type as IR +import LLVM.IRBuilder (emitBlockStart) +import LLVM.IRBuilder.Constant as C +import LLVM.IRBuilder.Instruction +import LLVM.IRBuilder.Module (MonadModuleBuilder) +import LLVM.IRBuilder.Monad (MonadIRBuilder, freshName) +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Codegen.Emit.Primitives (trap) +import Prelude hiding (or) + +-- Mutable FloatArray builtins +newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef :: Builtin +newFloatArrayDef = + Builtin "newFloatArray" [tyInt] tyFloatArray impl + where + impl [len] = newFloatArray len +readFloatArrayDef = + Builtin "readFloatArray" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl + where + impl [arrPtr, idx] = withBoundsCheck arrPtr idx $ readFloatArray arrPtr idx +writeFloatArrayDef = + Builtin "writeFloatArray" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl + where + impl [arrPtr, idx, val] = withBoundsCheck arrPtr idx $ do + dataPtr <- readStruct arrPtr 1 + writeData dataPtr idx val + return arrPtr +lengthFloatArrayDef = + Builtin "lengthFloatArray" [tyFloatArray] (tyPair (tyInt, tyFloatArray)) impl + where + impl [arrPtr] = lengthFloatArray arrPtr +deleteFloatArrayDef = + Builtin "deleteFloatArray" [tyFloatArray] tyUnit impl + where + impl [arrPtr] = do + dataPtr <- readStruct arrPtr 1 + _ <- free dataPtr + _ <- free arrPtr + -- return unit (need to check) + return $ ConstantOperand (C.Struct Nothing False []) + +-- Immutable FloatArray builtins +newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef :: Builtin +newFloatArrayIDef = + Builtin "newFloatArrayI" [tyInt] tyFloatArray impl + where + impl [len] = newFloatArray len +readFloatArrayIDef = + Builtin "readFloatArrayI" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl + where + impl [arrPtr, idx] = withBoundsCheck arrPtr idx $ readFloatArray arrPtr idx +writeFloatArrayIDef = + Builtin "writeFloatArrayI" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl + where + impl [arrPtr, idx, val] = withBoundsCheck arrPtr idx $ do + len <- readStruct arrPtr 0 + dataPtr <- readStruct arrPtr 1 + + -- len * double precision + size <- mul len (int32 8) + -- malloc wants i64 + size' <- sext size i64 + newDataPtr <- allocate size' IR.double + + -- copy existing data to new array + _ <- copy newDataPtr dataPtr size + + -- write value at index + writeData newDataPtr idx val + + -- return a new array struct + makeArrayStruct IR.double len newDataPtr +lengthFloatArrayIDef = + Builtin "lengthFloatArrayI" [tyFloatArray] (tyPair (tyInt, tyFloatArray)) impl + where + impl [arrPtr] = lengthFloatArray arrPtr + +-- arrayStruct specialisation +floatArrayStruct :: IR.Type +floatArrayStruct = arrayStruct IR.double + +-- creates a new float array of length +newFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +newFloatArray len = do + -- len * double precision + size <- mul len (int32 8) + -- malloc wants i64 + size' <- sext size i64 + dataPtr <- allocate size' IR.double + makeArrayStruct IR.double len dataPtr + +-- read the value at index of float array +readFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> m Operand +readFloatArray arrPtr idx = do + dataPtr <- readStruct arrPtr 1 + value <- readData dataPtr idx + makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) + +-- read the length of float array +lengthFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +lengthFloatArray arrPtr = do + len <- readStruct arrPtr 0 + makePair (i32, len) (ptr floatArrayStruct, arrPtr) + +withBoundsCheck :: + (MonadIRBuilder m, MonadModuleBuilder m) => + Operand -> + Operand -> + m Operand -> + m Operand +withBoundsCheck arrPtr idx continuation = do + len <- readStruct arrPtr 0 + ltZero <- icmp SLT idx (int32 0) + gteLen <- icmp SGE idx len + outOfBounds <- or ltZero gteLen + + abort <- freshName "out_of_bounds" + continue <- freshName "in_bounds" + + condBr outOfBounds abort continue + + emitBlockStart abort + trap + + emitBlockStart continue + continuation diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs new file mode 100644 index 0000000..29c52c5 --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE RankNTypes #-} +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.Shared where + +import LLVM.AST +import LLVM.AST.Type as IR +import LLVM.IRBuilder.Constant (bit, int32) +import LLVM.IRBuilder.Instruction (bitcast, call, gep, load, store, insertValue) +import LLVM.IRBuilder.Module +import LLVM.IRBuilder.Monad +import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) +import Language.Granule.Codegen.Emit.Primitives as P +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type as Gr +import qualified LLVM.AST.Constant as C + +data Builtin = Builtin { + builtinId :: String, + builtinArgTys :: [Gr.Type], + builtinRetTy :: Gr.Type, + builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} + +-- LLVM helpers + +allocate :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> IR.Type -> m Operand +allocate len ty = do + pointer <- call (ConstantOperand P.malloc) [(len, [])] + bitcast pointer (ptr ty) + +allocateStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> m Operand +allocateStruct ty = allocate (ConstantOperand $ sizeOf ty) ty + +copy :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> Operand -> m Operand +copy dst src len = do + dst' <- bitcast dst (ptr i8) + src' <- bitcast src (ptr i8) + call (ConstantOperand P.memcpy) [(dst', []), (src', []), (len, []), (bit 0, [])] + +free :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +free memPtr = do + memPtr' <- bitcast memPtr (ptr i8) + call (ConstantOperand P.free) [(memPtr', [])] + +makePair :: (MonadIRBuilder m, MonadModuleBuilder m) => (IR.Type, Operand) -> (IR.Type, Operand) -> m Operand +makePair (leftTy, leftVal) (rightTy, rightVal) = do + let pairTy = StructureType False [leftTy, rightTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair leftVal [0] + insertValue pair' rightVal [1] + +writeStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Int -> Operand -> m () +writeStruct struct index value = do + field <- gep struct [int32 0, int32 $ fromIntegral index] + store field 0 value + +readStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Int -> m Operand +readStruct struct index = do + field <- gep struct [int32 0, int32 $ fromIntegral index] + load field 0 + +-- Arrays + +arrayStruct :: IR.Type -> IR.Type +arrayStruct ty = StructureType False [i32, ptr ty] + +-- creates a struct for len and array data +makeArrayStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> Operand -> Operand -> m Operand +makeArrayStruct ty len dataPtr = do + arrPtr <- allocateStruct (arrayStruct ty) + writeStruct arrPtr 0 len + writeStruct arrPtr 1 dataPtr + return arrPtr + +writeData :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> Operand -> m () +writeData dataPtr index value = do + valuePtr <- gep dataPtr [index] + store valuePtr 0 value + +readData :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> m Operand +readData dataPtr index = do + valuePtr <- gep dataPtr [index] + load valuePtr 0 + +-- Granule types + +mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type +mkFunType args ret = foldr (FunTy Nothing Nothing) ret args + +tyInt :: Gr.Type +tyInt = TyCon (Id "Int" "Int") + +tyFloat :: Gr.Type +tyFloat = TyCon (Id "Float" "Float") + +tyChar :: Gr.Type +tyChar = TyCon (Id "Char" "Char") + +tyUnit :: Gr.Type +tyUnit = TyCon (Id "()" "()") + +tyPair :: (Gr.Type, Gr.Type) -> Gr.Type +tyPair (l, r) = TyApp (TyApp (TyCon (Id "," ",")) l) r + +tyFloatArray :: Gr.Type +tyFloatArray = TyApp (TyCon (Id "FloatArray" "FloatArray")) (TyVar (Id "id" "id")) diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index b747701..892c241 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE ImplicitParams #-} {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} module Language.Granule.Codegen.Compile where @@ -9,13 +8,14 @@ import Language.Granule.Codegen.TopsortDefinitions import Language.Granule.Codegen.ConvertClosures import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals +import Language.Granule.Codegen.RewriteAST + import qualified LLVM.AST as IR ---import Language.Granule.Syntax.Pretty ---import Debug.Trace compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let normalised = {-trace (show typedAST)-} (normaliseDefinitions typedAST) + let rewritten = rewriteAST typedAST + normalised = normaliseDefinitions rewritten markedGlobals = markGlobals normalised (Ok topsorted) = topologicallySortDefinitions markedGlobals closureFree = convertClosures topsorted diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index 678760c..728990d 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -11,12 +11,12 @@ import LLVM.IRBuilder.Constant (int32) import LLVM.IRBuilder.Instruction import LLVM.IRBuilder.Module import LLVM.IRBuilder.Monad (IRBuilderT) -import Language.Granule.Codegen.Builtins +import Language.Granule.Codegen.Builtins.Builtins +import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Emit.LLVMHelpers import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment) import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction) --- TODO: only emit builtins as required emitBuiltins :: (MonadModuleBuilder m) => m [Operand] emitBuiltins = mapM emitBuiltin builtins diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index 8d324bf..fae0c94 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -45,6 +45,8 @@ emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = _ <- extern (mkName "malloc") [i64] (ptr i8) _ <- extern (mkName "abort") [] void _ <- externVarArgs (mkName "printf") [ptr i8] i32 + _ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void + _ <- extern (mkName "free") [ptr i8] void _ <- emitBuiltins let mainTy = findMainReturnType valueDefs _ <- emitMainOut mainTy diff --git a/src/Language/Granule/Codegen/Emit/LowerType.hs b/src/Language/Granule/Codegen/Emit/LowerType.hs index de18e1a..b8ec373 100644 --- a/src/Language/Granule/Codegen/Emit/LowerType.hs +++ b/src/Language/Granule/Codegen/Emit/LowerType.hs @@ -44,6 +44,8 @@ llvmType (FunTy _ _ from to) = llvmTypeForClosure $ llvmTypeForFunction (llvmType from) (llvmType to) llvmType (TyApp (TyApp (TyCon (MkId ",")) left) right) = StructureType False [llvmType left, llvmType right] +llvmType (TyApp (TyCon (MkId "FloatArray")) _) = + ptr $ StructureType False [i32, ptr double] llvmType (TyCon (MkId "()")) = StructureType False [] llvmType (TyCon (MkId "Int")) = i32 @@ -52,6 +54,8 @@ llvmType (TyCon (MkId "Char")) = i8 llvmType (TyCon (MkId "Handle")) = i8 llvmType (TyCon (MkId "Bool")) = i1 llvmType (Box coeffect ty) = llvmType ty +llvmType (TyExists _ _ ty) = llvmType ty +llvmType (Borrow _ ty) = llvmType ty llvmType ty = error $ "Cannot lower the type " ++ show ty llvmTopLevelType :: GrType -> IrType diff --git a/src/Language/Granule/Codegen/Emit/MainOut.hs b/src/Language/Granule/Codegen/Emit/MainOut.hs index c4a43fa..e79efb3 100644 --- a/src/Language/Granule/Codegen/Emit/MainOut.hs +++ b/src/Language/Granule/Codegen/Emit/MainOut.hs @@ -62,5 +62,7 @@ fmtStrForTy x = (TyCon (Id "Float" _)) -> "%.6f" (TyApp (TyApp (TyCon (Id "," _)) leftTy) rightTy) -> "(" ++ fmtStrForTy leftTy ++ ", " ++ fmtStrForTy rightTy ++ ")" + (TyApp (TyCon (Id "FloatArray" _)) _) -> "" (TyCon (Id "()" _)) -> "()" - _ -> error "Unsupported" + (TyExists _ _ (Borrow _ ty)) -> "*" ++ fmtStrForTy ty + _ -> error ("Unsupported Main type: " ++ show x) diff --git a/src/Language/Granule/Codegen/Emit/Primitives.hs b/src/Language/Granule/Codegen/Emit/Primitives.hs index c3ed98e..c5870ea 100644 --- a/src/Language/Granule/Codegen/Emit/Primitives.hs +++ b/src/Language/Granule/Codegen/Emit/Primitives.hs @@ -4,7 +4,7 @@ import LLVM.IRBuilder.Instruction import LLVM.AST (mkName, Operand(..)) import LLVM.AST.Constant (Constant, Constant(..)) -import LLVM.AST.Type (i8, i32, i64, ptr, void, Type(..)) +import LLVM.AST.Type (i1, i8, i32, i64, ptr, void, Type(..)) import LLVM.IRBuilder (MonadModuleBuilder) malloc :: Constant @@ -25,3 +25,13 @@ printf :: Constant printf = GlobalReference functionType name where name = mkName "printf" functionType = ptr (FunctionType i32 [ptr i8] True) + +memcpy :: Constant +memcpy = GlobalReference functionType name + where name = mkName "llvm.memcpy.p0.p0.i32" + functionType = ptr (FunctionType void [ptr i8, ptr i8, i32, i1] False) + +free :: Constant +free = GlobalReference functionType name + where name = mkName "free" + functionType = ptr (FunctionType void [ptr i8] False) diff --git a/src/Language/Granule/Codegen/MarkGlobals.hs b/src/Language/Granule/Codegen/MarkGlobals.hs index c9ecae3..64dc0bc 100644 --- a/src/Language/Granule/Codegen/MarkGlobals.hs +++ b/src/Language/Granule/Codegen/MarkGlobals.hs @@ -5,7 +5,7 @@ import Language.Granule.Syntax.Type import Language.Granule.Syntax.Identifiers import Language.Granule.Syntax.Pretty import Data.Bifunctor.Foldable -import Language.Granule.Codegen.Builtins (builtinIds) +import Language.Granule.Codegen.Builtins.Builtins (builtinIds) data GlobalMarker = GlobalVar Type Id diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs new file mode 100644 index 0000000..32f710b --- /dev/null +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -0,0 +1,134 @@ +module Language.Granule.Codegen.RewriteAST where + +import Data.Bifunctor (bimap) +import Data.List (mapAccumL) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers (Id) +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- Rewrite Unpack ASTs into App Abs ASTs which our +-- compiler already knows how to handle. WIP. + +rewriteAST :: AST ev Type -> AST ev Type +rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} + where + rewriteDef def = def {defEquations = rewriteEquationList (defEquations def)} + rewriteEquationList eqs = eqs {equations = map rewriteEquation (equations eqs)} + rewriteEquation eq = eq {equationBody = rewriteExpr (equationBody eq)} + +rewriteExpr :: Expr ev Type -> Expr ev Type +rewriteExpr (Unpack s retTy b tyVar var e1 e2) = + let e1' = e1 + e1Ty = exprTy e1' + e2' = e2 + absTy = FunTy Nothing Nothing e1Ty retTy + in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') + where + fixTypes expr = snd $ substExpr emptyEnv expr +rewriteExpr exp = exp + +exprTy :: Expr ev Type -> Type +exprTy (App _ ty _ _ _) = ty +exprTy (Val _ ty _ _) = ty +exprTy (Binop _ ty _ _ _ _) = ty +exprTy (LetDiamond _ ty _ _ _ _ _) = ty +exprTy (Case _ ty _ _ _) = ty +exprTy (Hole _ ty _ _ _) = ty +exprTy (AppTy _ ty _ _ _) = ty +exprTy (TryCatch _ ty _ _ _ _ _ _) = ty +exprTy (Unpack _ ty _ _ _ _ _) = ty + +-- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these +-- are not already handled by the compiler. Here we find the correct types and substitute +-- the TyVars. WIP. + +-- val var -> Type, type var -> Type +type Env = (Map.Map Id Type, Map.Map Id Type) + +emptyEnv :: Env +emptyEnv = (Map.empty, Map.empty) + +insertEnv :: Env -> Either Id Id -> Type -> Env +insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) +insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) + +lookupEnv :: Env -> Either Id Id -> Maybe Type +lookupEnv (vals, tys) (Left id) = Map.lookup id vals +lookupEnv (vals, tys) (Right id) = Map.lookup id tys + +substExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) +substExpr env (App s ty b e1 e2) = + let (env', e2') = substExpr env e2 + (env'', e1') = substExpr env' e1 + ty' = substTy env ty + in (env'', App s ty' b e1' e2') +substExpr env (Val s ty b v) = + let (env', v') = substVal env v + ty' = substTy env' ty + in (env', Val s ty' b v') +substExpr env exp = error "TODO expr" + +substVal :: Env -> Value ev Type -> (Env, Value ev Type) +substVal env (Var (TyVar id) var) = + -- see if we already have it + case lookupEnv env (Right id) of + Just ty -> (env, Var ty var) + Nothing -> + -- see if the value variable has it + case lookupEnv env (Left var) of + -- and update + Just ty -> (insertEnv env (Right id) ty, Var ty var) + -- we wont always win + Nothing -> (env, Var (TyVar id) var) +substVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) +substVal env (Abs ty p mt e) = + let (env', p') = substPat env p + (env'', e') = substExpr env' e + ty' = substTy env'' ty + in (env'', Abs ty' p' mt e') +substVal env (Constr ty id vals) = + let (env', vals') = mapAccumL substVal env vals + ty' = substTy env' ty + in (env', Constr ty' id vals') +substVal env (NumInt v) = (env, NumInt v) +substVal env (NumFloat v) = (env, NumFloat v) +substVal env (Promote t v) = (env, Promote t v) +substVal env val = error "TODO val" + +substPat :: Env -> Pattern Type -> (Env, Pattern Type) +substPat env (PVar s (TyVar id) b var) = + case lookupEnv env (Right id) of + Just ty -> (env, PVar s ty b var) + Nothing -> + case lookupEnv env (Left var) of + Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) + Nothing -> (env, PVar s (TyVar id) b var) +substPat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) +substPat env (PConstr s ty b id ids ps) = + let (env', ps') = mapAccumL substPat env ps + ty' = substTy env' ty + in (env', PConstr s ty' b id ids ps') +substPat env p = error "TODO pat" + +substTy :: Env -> Type -> Type +substTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) +substTy env (Type i) = Type i +substTy env (FunTy id mc arg ret) = FunTy id mc (substTy env arg) (substTy env ret) +substTy env (TyCon id) = TyCon id +substTy env (Box c t) = substTy env t +substTy env (Diamond e t) = Diamond (substTy env e) (substTy env t) +substTy env (Star g t) = substTy env t +substTy env (Borrow p t) = substTy env t +substTy env (TyApp t1 t2) = TyApp (substTy env t1) (substTy env t2) +substTy env (TyGrade mt i) = TyGrade mt i +substTy env (TyInfix op t1 t2) = TyInfix op (substTy env t1) (substTy env t2) +substTy env (TySet p ts) = TySet p (map (substTy env) ts) +substTy env (TyCase t tps) = TyCase (substTy env t) (map (bimap (substTy env) (substTy env)) tps) +substTy env (TySig t k) = TySig (substTy env t) (substTy env k) +substTy env (TyExists id k t) = substTy env t +substTy env (TyForall id k t) = substTy env t +substTy env t = t diff --git a/tests/golden/positive/array.golden b/tests/golden/positive/array.golden new file mode 100644 index 0000000..e563993 --- /dev/null +++ b/tests/golden/positive/array.golden @@ -0,0 +1 @@ +(40.000000, (100.000000, (2, ))) diff --git a/tests/golden/positive/array.gr b/tests/golden/positive/array.gr new file mode 100644 index 0000000..0a1a580 --- /dev/null +++ b/tests/golden/positive/array.gr @@ -0,0 +1,10 @@ +write2 : forall {id : Name} . FloatArray id -> (Float, Float) -> FloatArray id +write2 arr (x, y) = writeFloatArrayI (writeFloatArrayI arr 0 x) 1 y + +main : forall {id : Name} . (Float, (Float, (Int, FloatArray id))) +main = + let arr = write2 (newFloatArrayI 2) (40.0, 100.0); + (x, arr') = readFloatArrayI arr 0; + (y, arr'') = readFloatArrayI arr' 1; + (l, arr''') = lengthFloatArrayI arr'' in + (x, (y, (l, arr'''))) diff --git a/tests/golden/positive/boxes.golden b/tests/golden/positive/boxes.golden new file mode 100644 index 0000000..10e0f08 --- /dev/null +++ b/tests/golden/positive/boxes.golden @@ -0,0 +1 @@ +42.000000 diff --git a/tests/golden/positive/boxes.gr b/tests/golden/positive/boxes.gr new file mode 100644 index 0000000..e1b6d9b --- /dev/null +++ b/tests/golden/positive/boxes.gr @@ -0,0 +1,10 @@ +deleteFloatArrayI : forall {id : Name} . (FloatArray id) [] -> () +deleteFloatArrayI [a] = () + +main : Float +main = + let [arr] = [(newFloatArrayI 1)] in + let [arr'] = [writeFloatArrayI arr 0 42.0] in + let [(x, arr'')] = [readFloatArrayI arr' 0] in + let () = deleteFloatArrayI [arr''] in + x diff --git a/tests/golden/positive/unpack.golden b/tests/golden/positive/unpack.golden new file mode 100644 index 0000000..bd55d74 --- /dev/null +++ b/tests/golden/positive/unpack.golden @@ -0,0 +1 @@ +(42.000000, 100.000000) diff --git a/tests/golden/positive/unpack.gr b/tests/golden/positive/unpack.gr new file mode 100644 index 0000000..db537f9 --- /dev/null +++ b/tests/golden/positive/unpack.gr @@ -0,0 +1,9 @@ +main : (Float, Float) +main = + unpack = newFloatArray 2 in + let arr' = writeFloatArray arr 0 42.0 in + let arr'' = writeFloatArray arr' 1 100.0 in + let (x, arr''') = readFloatArray arr'' 0 in + let (y, arr'''') = readFloatArray arr''' 1 in + let () = deleteFloatArray arr'''' in + (x, y)