From 9a7739e7ac8195ab572001a53cd3c50037a92a8a Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Fri, 10 Jan 2025 12:23:45 +0000 Subject: [PATCH 01/25] implement immutable array builtins --- src/Language/Granule/Codegen/Builtins.hs | 128 +++++++++++++++++- src/Language/Granule/Codegen/Emit/EmitLLVM.hs | 1 + .../Granule/Codegen/Emit/LowerType.hs | 2 + src/Language/Granule/Codegen/Emit/MainOut.hs | 1 + .../Granule/Codegen/Emit/Primitives.hs | 7 +- 5 files changed, 133 insertions(+), 6 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins.hs b/src/Language/Granule/Codegen/Builtins.hs index c48ea5a..fb4789b 100644 --- a/src/Language/Granule/Codegen/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins.hs @@ -3,11 +3,15 @@ module Language.Granule.Codegen.Builtins where -import LLVM.AST (Operand) +import LLVM.AST +import qualified LLVM.AST.Constant as C import LLVM.AST.Type as IR -import LLVM.IRBuilder (MonadIRBuilder, sdiv) -import LLVM.IRBuilder.Instruction (zext) -import LLVM.IRBuilder.Module (MonadModuleBuilder) +import LLVM.IRBuilder.Constant as C +import LLVM.IRBuilder.Instruction +import LLVM.IRBuilder.Module +import LLVM.IRBuilder.Monad +import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) +import Language.Granule.Codegen.Emit.Primitives (malloc, memcpy) import Language.Granule.Syntax.Identifiers import Language.Granule.Syntax.Type as Gr @@ -21,7 +25,8 @@ mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type mkFunType args ret = foldr (FunTy Nothing Nothing) ret args builtins :: [Builtin] -builtins = [charToIntDef, divDef] +builtins = [charToIntDef, divDef, newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef ] + builtinIds :: [Id] builtinIds = map (mkId . builtinId) builtins @@ -44,6 +49,119 @@ divDef = ret = TyCon (Id "Int" "Int") impl [x, y] = sdiv x y +-- newFloatArrayI :: Int -> FloatArray id +newFloatArrayIDef :: Builtin +newFloatArrayIDef = + Builtin "newFloatArrayI" args ret impl + where + args = [tyInt] + ret = tyFloatArray + impl [len] = do + -- arrays are a struct {int32 len, double* data} - on heap - use stack? + arrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] + arrPtr' <- bitcast arrPtr (ptr structTy) + + -- not 100% on the size and if we need to do anything for alignment, but it works + dataSize <- mul len (int32 8) + dataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + + lenField <- gep arrPtr' [int32 0, int32 0] + store lenField 0 len + + dataField <- gep arrPtr' [int32 0, int32 1] + store dataField 0 dataPtr + + return arrPtr + +-- readFloatArrayI :: (FloatArray id) -> Int -> (Float, FloatArray id) +readFloatArrayIDef :: Builtin +readFloatArrayIDef = + Builtin "readFloatArrayI" args ret impl + where + args = [tyFloatArray, tyInt] + ret = tyPair (tyFloat, tyFloatArray) + impl [arrPtr, idx] = do + -- arr -> data -> idx -> val + + arrPtr' <- bitcast arrPtr (ptr structTy) + + dataField <- gep arrPtr' [int32 0, int32 1] + dataPtr <- load dataField 0 + + valuePtr <- gep dataPtr [idx] + value <- load valuePtr 0 + + -- pair return (float, floatArray) on stack + let pairTy = StructureType False [IR.double, ptr structTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair value [0] + insertValue pair' arrPtr [1] + + + +-- writeFloatArrayI :: (FloatArray id) -> Int -> Float -> FloatArray id +writeFloatArrayIDef :: Builtin +writeFloatArrayIDef = + Builtin "writeFloatArrayI" args ret impl + where + args = [tyFloatArray, tyInt, tyFloat] + ret = tyFloatArray + impl [arrPtr, idx, val] = do + arrPtr' <- bitcast arrPtr (ptr structTy) + + lenField <- gep arrPtr' [int32 0, int32 0] + len <- load lenField 0 + + dataField <- gep arrPtr' [int32 0, int32 1] + dataPtr <- load dataField 0 + + -- need to create a new array as in newFloatArrayI + newArrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] + newArrPtr' <- bitcast newArrPtr (ptr structTy) + + dataSize <- mul len (int32 8) + newDataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + newDataPtr' <- bitcast newDataPtr (ptr IR.double) + + -- copy the existing data to new array + _ <- call (ConstantOperand memcpy) + [ (newDataPtr, []) + , (dataPtr, []) + , (dataSize, []) + , (bit 0, []) + ] + + -- write the val to new copy + valuePtr <- gep newDataPtr' [idx] + store valuePtr 0 val + + newLenField <- gep newArrPtr' [int32 0, int32 0] + store newLenField 0 len + + newDataField <- gep newArrPtr' [int32 0, int32 1] + store newDataField 0 newDataPtr + + return newArrPtr + +-- lengthFloatArrayI :: (FloatArray id) -> (Int -> FloatArray id) +lengthFloatArrayIDef :: Builtin +lengthFloatArrayIDef = + Builtin "lengthFloatArrayI" args ret impl + where + args = [tyFloatArray] + ret = tyPair (tyInt, tyFloatArray) + impl [arrPtr] = do + arrPtr' <- bitcast arrPtr (ptr structTy) + + lenField <- gep arrPtr' [int32 0, int32 0] + len <- load lenField 0 + + -- pair return (int, floatArray) on stack + let pairTy = StructureType False [i32, ptr structTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair len [0] + insertValue pair' arrPtr [1] + structTy :: IR.Type structTy = StructureType False [i32, ptr IR.double] diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index d65c138..9efda17 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -45,6 +45,7 @@ emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = _ <- extern (mkName "malloc") [i64] (ptr i8) _ <- extern (mkName "abort") [] void _ <- externVarArgs (mkName "printf") [ptr i8] i32 + _ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void _ <- emitBuiltins let mainTy = findMainReturnType valueDefs _ <- emitMainOut mainTy diff --git a/src/Language/Granule/Codegen/Emit/LowerType.hs b/src/Language/Granule/Codegen/Emit/LowerType.hs index 5f4ef47..d2c1ae3 100644 --- a/src/Language/Granule/Codegen/Emit/LowerType.hs +++ b/src/Language/Granule/Codegen/Emit/LowerType.hs @@ -44,6 +44,8 @@ llvmType (FunTy _ _ from to) = llvmTypeForClosure $ llvmTypeForFunction (llvmType from) (llvmType to) llvmType (TyApp (TyApp (TyCon (MkId ",")) left) right) = StructureType False [llvmType left, llvmType right] +llvmType (TyApp (TyCon (MkId "FloatArray")) _) = + ptr $ StructureType False [i32, ptr double] llvmType (TyCon (MkId "Int")) = i32 llvmType (TyCon (MkId "Float")) = double llvmType (TyCon (MkId "Char")) = i8 diff --git a/src/Language/Granule/Codegen/Emit/MainOut.hs b/src/Language/Granule/Codegen/Emit/MainOut.hs index 5cde7a3..1d1dcf9 100644 --- a/src/Language/Granule/Codegen/Emit/MainOut.hs +++ b/src/Language/Granule/Codegen/Emit/MainOut.hs @@ -62,4 +62,5 @@ fmtStrForTy x = (TyCon (Id "Float" _)) -> "%.6f" (TyApp (TyApp (TyCon (Id "," _)) leftTy) rightTy) -> "(" ++ fmtStrForTy leftTy ++ ", " ++ fmtStrForTy rightTy ++ ")" + (TyApp (TyCon (Id "FloatArray" _)) _) -> "" _ -> error "Unsupported" diff --git a/src/Language/Granule/Codegen/Emit/Primitives.hs b/src/Language/Granule/Codegen/Emit/Primitives.hs index c3ed98e..12e2aba 100644 --- a/src/Language/Granule/Codegen/Emit/Primitives.hs +++ b/src/Language/Granule/Codegen/Emit/Primitives.hs @@ -4,7 +4,7 @@ import LLVM.IRBuilder.Instruction import LLVM.AST (mkName, Operand(..)) import LLVM.AST.Constant (Constant, Constant(..)) -import LLVM.AST.Type (i8, i32, i64, ptr, void, Type(..)) +import LLVM.AST.Type (i1, i8, i32, i64, ptr, void, Type(..)) import LLVM.IRBuilder (MonadModuleBuilder) malloc :: Constant @@ -25,3 +25,8 @@ printf :: Constant printf = GlobalReference functionType name where name = mkName "printf" functionType = ptr (FunctionType i32 [ptr i8] True) + +memcpy :: Constant +memcpy = GlobalReference functionType name + where name = mkName "llvm.memcpy.p0.p0.i32" + functionType = ptr (FunctionType void [ptr i8, ptr i8, i32, i1] False) From 5163fec5b1deb5d517a331d7945b76b508aa5bb1 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Fri, 10 Jan 2025 12:25:30 +0000 Subject: [PATCH 02/25] capturing unwanted expressions. this breaks pairs but fixes arrays --- .../Granule/Codegen/Emit/LowerExpression.hs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/Language/Granule/Codegen/Emit/LowerExpression.hs b/src/Language/Granule/Codegen/Emit/LowerExpression.hs index 0c5a8e1..804019c 100644 --- a/src/Language/Granule/Codegen/Emit/LowerExpression.hs +++ b/src/Language/Granule/Codegen/Emit/LowerExpression.hs @@ -49,16 +49,17 @@ emitExpr :: (MonadState EmitterState m, MonadModuleBuilder m, MonadIRBuilder m, => Maybe Operand -> ExprF (Either GlobalMarker ClosureMarker) Type (EmitableExpr, m Operand) (EmitableValue, m Operand) -> m Operand -emitExpr environment (AppF _ (FunTy _ _ _ (TyApp (TyApp (TyCon (Id "," _)) _) _)) _ _ (_, emitArg)) = emitArg - -emitExpr environment (AppF _ (TyApp (TyApp (TyCon (Id "," _)) leftTy) rightTy) _ (ExprFix2 (AppF {}), emitFunction) (_, emitArg)) = - do - leftVal <- emitFunction - rightVal <- emitArg - let pairTy = IRType.StructureType False [llvmType leftTy, llvmType rightTy] - let pair = IR.ConstantOperand $ C.Undef pairTy - pair' <- insertValue pair leftVal [0] - insertValue pair' rightVal [1] +-- TODO - maybe this should be handled elsewhere? doesn't play nice with our builtins +-- emitExpr environment (AppF _ (FunTy _ _ _ (TyApp (TyApp (TyCon (Id "," _)) _) _)) _ _ (_, emitArg)) = emitArg + +-- emitExpr environment (AppF _ (TyApp (TyApp (TyCon (Id "," _)) leftTy) rightTy) _ (ExprFix2 (AppF {}), emitFunction) (_, emitArg)) = +-- do +-- leftVal <- emitFunction +-- rightVal <- emitArg +-- let pairTy = IRType.StructureType False [llvmType leftTy, llvmType rightTy] +-- let pair = IR.ConstantOperand $ C.Undef pairTy +-- pair' <- insertValue pair leftVal [0] +-- insertValue pair' rightVal [1] emitExpr environment (AppF _ ty _ (_, emitFunction) (_, emitArg)) = do From 37bd8347857a8cfc56878f22ff65403660c71e86 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 23 Jan 2025 12:30:08 +0000 Subject: [PATCH 03/25] cleanup comments --- src/Language/Granule/Codegen/Builtins.hs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins.hs b/src/Language/Granule/Codegen/Builtins.hs index fb4789b..f1457fc 100644 --- a/src/Language/Granule/Codegen/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins.hs @@ -57,11 +57,9 @@ newFloatArrayIDef = args = [tyInt] ret = tyFloatArray impl [len] = do - -- arrays are a struct {int32 len, double* data} - on heap - use stack? arrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] arrPtr' <- bitcast arrPtr (ptr structTy) - - -- not 100% on the size and if we need to do anything for alignment, but it works + -- length * double precision 8 bytes dataSize <- mul len (int32 8) dataPtr <- call (ConstantOperand malloc) [(dataSize, [])] From d88459a296695e056e748044333adf1a3fcd3015 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 23 Jan 2025 12:32:34 +0000 Subject: [PATCH 04/25] hide some mess --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 48f0299..fa56521 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .stack-work/ stack.yaml.lock +.tmp/ From bcd37cf6cb4efbd796e87ea9d98a2e1455cc8406 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 2 Mar 2025 21:58:39 +0000 Subject: [PATCH 05/25] array test --- tests/golden/positive/array.golden | 1 + tests/golden/positive/array.gr | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 tests/golden/positive/array.golden create mode 100644 tests/golden/positive/array.gr diff --git a/tests/golden/positive/array.golden b/tests/golden/positive/array.golden new file mode 100644 index 0000000..e563993 --- /dev/null +++ b/tests/golden/positive/array.golden @@ -0,0 +1 @@ +(40.000000, (100.000000, (2, ))) diff --git a/tests/golden/positive/array.gr b/tests/golden/positive/array.gr new file mode 100644 index 0000000..0a1a580 --- /dev/null +++ b/tests/golden/positive/array.gr @@ -0,0 +1,10 @@ +write2 : forall {id : Name} . FloatArray id -> (Float, Float) -> FloatArray id +write2 arr (x, y) = writeFloatArrayI (writeFloatArrayI arr 0 x) 1 y + +main : forall {id : Name} . (Float, (Float, (Int, FloatArray id))) +main = + let arr = write2 (newFloatArrayI 2) (40.0, 100.0); + (x, arr') = readFloatArrayI arr 0; + (y, arr'') = readFloatArrayI arr' 1; + (l, arr''') = lengthFloatArrayI arr'' in + (x, (y, (l, arr'''))) From 79fd00d33b5005d6dcf15e494adf47bc798572aa Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 2 Mar 2025 22:26:43 +0000 Subject: [PATCH 06/25] move builtins to their own folder, update imports --- granule-compiler.cabal | 5 +- .../Granule/Codegen/Builtins/Builtins.hs | 12 ++++ .../Granule/Codegen/Builtins/Extras.hs | 27 ++++++++ .../ImmutableArray.hs} | 64 +------------------ .../Granule/Codegen/Builtins/Shared.hs | 41 ++++++++++++ .../Granule/Codegen/Emit/EmitBuiltins.hs | 4 +- src/Language/Granule/Codegen/MarkGlobals.hs | 2 +- 7 files changed, 90 insertions(+), 65 deletions(-) create mode 100644 src/Language/Granule/Codegen/Builtins/Builtins.hs create mode 100644 src/Language/Granule/Codegen/Builtins/Extras.hs rename src/Language/Granule/Codegen/{Builtins.hs => Builtins/ImmutableArray.hs} (69%) create mode 100644 src/Language/Granule/Codegen/Builtins/Shared.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 600c5f8..170de4e 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -29,7 +29,10 @@ library Language.Granule.Codegen.Compile Language.Granule.Codegen.MarkGlobals other-modules: - Language.Granule.Codegen.Builtins + Language.Granule.Codegen.Builtins.Builtins + Language.Granule.Codegen.Builtins.Extras + Language.Granule.Codegen.Builtins.ImmutableArray + Language.Granule.Codegen.Builtins.Shared Language.Granule.Codegen.Emit.EmitableDef Language.Granule.Codegen.Emit.EmitBuiltins Language.Granule.Codegen.Emit.EmitterState diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs new file mode 100644 index 0000000..0fc0f3f --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -0,0 +1,12 @@ +module Language.Granule.Codegen.Builtins.Builtins where + +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Codegen.Builtins.ImmutableArray +import Language.Granule.Codegen.Builtins.Extras +import Language.Granule.Syntax.Identifiers (Id, mkId) + +builtins :: [Builtin] +builtins = [charToIntDef, divDef, newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef ] + +builtinIds :: [Id] +builtinIds = map (mkId . builtinId) builtins diff --git a/src/Language/Granule/Codegen/Builtins/Extras.hs b/src/Language/Granule/Codegen/Builtins/Extras.hs new file mode 100644 index 0000000..63e96b8 --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Extras.hs @@ -0,0 +1,27 @@ +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.Extras where + +import LLVM.AST.Type as IR +import LLVM.IRBuilder.Instruction +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type as Gr + +-- charToInt :: Char -> Int +charToIntDef :: Builtin +charToIntDef = + Builtin "charToInt" args ret impl + where + args = [TyCon (Id "Char" "Char")] + ret = TyCon (Id "Int" "Int") + impl [x] = zext x i32 + +-- div :: Int -> Int -> Int +divDef :: Builtin +divDef = + Builtin "div" args ret impl + where + args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] + ret = TyCon (Id "Int" "Int") + impl [x, y] = sdiv x y diff --git a/src/Language/Granule/Codegen/Builtins.hs b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs similarity index 69% rename from src/Language/Granule/Codegen/Builtins.hs rename to src/Language/Granule/Codegen/Builtins/ImmutableArray.hs index f1457fc..9d177ea 100644 --- a/src/Language/Granule/Codegen/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs @@ -1,53 +1,15 @@ -{-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} -module Language.Granule.Codegen.Builtins where +module Language.Granule.Codegen.Builtins.ImmutableArray where import LLVM.AST -import qualified LLVM.AST.Constant as C import LLVM.AST.Type as IR +import qualified LLVM.AST.Constant as C import LLVM.IRBuilder.Constant as C import LLVM.IRBuilder.Instruction -import LLVM.IRBuilder.Module -import LLVM.IRBuilder.Monad +import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) import Language.Granule.Codegen.Emit.Primitives (malloc, memcpy) -import Language.Granule.Syntax.Identifiers -import Language.Granule.Syntax.Type as Gr - -data Builtin = Builtin { - builtinId :: String, - builtinArgTys :: [Gr.Type], - builtinRetTy :: Gr.Type, - builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} - -mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type -mkFunType args ret = foldr (FunTy Nothing Nothing) ret args - -builtins :: [Builtin] -builtins = [charToIntDef, divDef, newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef ] - - -builtinIds :: [Id] -builtinIds = map (mkId . builtinId) builtins - --- charToInt :: Char -> Int -charToIntDef :: Builtin -charToIntDef = - Builtin "charToInt" args ret impl - where - args = [TyCon (Id "Char" "Char")] - ret = TyCon (Id "Int" "Int") - impl [x] = zext x i32 - --- div :: Int -> Int -> Int -divDef :: Builtin -divDef = - Builtin "div" args ret impl - where - args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] - ret = TyCon (Id "Int" "Int") - impl [x, y] = sdiv x y -- newFloatArrayI :: Int -> FloatArray id newFloatArrayIDef :: Builtin @@ -89,7 +51,6 @@ readFloatArrayIDef = valuePtr <- gep dataPtr [idx] value <- load valuePtr 0 - -- pair return (float, floatArray) on stack let pairTy = StructureType False [IR.double, ptr structTy] let pair = ConstantOperand $ C.Undef pairTy pair' <- insertValue pair value [0] @@ -154,26 +115,7 @@ lengthFloatArrayIDef = lenField <- gep arrPtr' [int32 0, int32 0] len <- load lenField 0 - -- pair return (int, floatArray) on stack let pairTy = StructureType False [i32, ptr structTy] let pair = ConstantOperand $ C.Undef pairTy pair' <- insertValue pair len [0] insertValue pair' arrPtr [1] - -structTy :: IR.Type -structTy = StructureType False [i32, ptr IR.double] - -tyInt :: Gr.Type -tyInt = TyCon (Id "Int" "Int") - -tyFloat :: Gr.Type -tyFloat = TyCon (Id "Float" "Float") - -tyChar :: Gr.Type -tyChar = TyCon (Id "Char" "Char") - -tyPair :: (Gr.Type, Gr.Type) -> Gr.Type -tyPair (l, r) = TyApp (TyApp (TyCon (Id "," ",")) l) r - -tyFloatArray :: Gr.Type -tyFloatArray = TyApp (TyCon (Id "FloatArray" "FloatArray")) (TyVar (Id "id" "id")) diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs new file mode 100644 index 0000000..095dac1 --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -0,0 +1,41 @@ +{-# LANGUAGE RankNTypes #-} +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.Shared where + +import LLVM.AST +import LLVM.AST.Type as IR +import LLVM.IRBuilder.Module +import LLVM.IRBuilder.Monad + +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type as Gr + +data Builtin = Builtin { + builtinId :: String, + builtinArgTys :: [Gr.Type], + builtinRetTy :: Gr.Type, + builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} + +-- helpers + +mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type +mkFunType args ret = foldr (FunTy Nothing Nothing) ret args + +structTy :: IR.Type +structTy = StructureType False [i32, ptr IR.double] + +tyInt :: Gr.Type +tyInt = TyCon (Id "Int" "Int") + +tyFloat :: Gr.Type +tyFloat = TyCon (Id "Float" "Float") + +tyChar :: Gr.Type +tyChar = TyCon (Id "Char" "Char") + +tyPair :: (Gr.Type, Gr.Type) -> Gr.Type +tyPair (l, r) = TyApp (TyApp (TyCon (Id "," ",")) l) r + +tyFloatArray :: Gr.Type +tyFloatArray = TyApp (TyCon (Id "FloatArray" "FloatArray")) (TyVar (Id "id" "id")) diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index 678760c..728990d 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -11,12 +11,12 @@ import LLVM.IRBuilder.Constant (int32) import LLVM.IRBuilder.Instruction import LLVM.IRBuilder.Module import LLVM.IRBuilder.Monad (IRBuilderT) -import Language.Granule.Codegen.Builtins +import Language.Granule.Codegen.Builtins.Builtins +import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Emit.LLVMHelpers import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment) import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction) --- TODO: only emit builtins as required emitBuiltins :: (MonadModuleBuilder m) => m [Operand] emitBuiltins = mapM emitBuiltin builtins diff --git a/src/Language/Granule/Codegen/MarkGlobals.hs b/src/Language/Granule/Codegen/MarkGlobals.hs index 7886825..8d27303 100644 --- a/src/Language/Granule/Codegen/MarkGlobals.hs +++ b/src/Language/Granule/Codegen/MarkGlobals.hs @@ -5,7 +5,7 @@ import Language.Granule.Syntax.Type import Language.Granule.Syntax.Identifiers import Language.Granule.Syntax.Pretty import Data.Bifunctor.Foldable -import Language.Granule.Codegen.Builtins (builtinIds) +import Language.Granule.Codegen.Builtins.Builtins (builtinIds) data GlobalMarker = GlobalVar Type Id From 2389bf58c0094340819c683b6d92268d37944179 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 2 Mar 2025 22:42:44 +0000 Subject: [PATCH 07/25] initial llvm implementation --- granule-compiler.cabal | 1 + .../Granule/Codegen/Builtins/MutableArray.hs | 115 ++++++++++++++++++ .../Granule/Codegen/Builtins/Shared.hs | 3 + src/Language/Granule/Codegen/Emit/EmitLLVM.hs | 1 + .../Granule/Codegen/Emit/Primitives.hs | 5 + 5 files changed, 125 insertions(+) create mode 100644 src/Language/Granule/Codegen/Builtins/MutableArray.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 170de4e..b85192b 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -32,6 +32,7 @@ library Language.Granule.Codegen.Builtins.Builtins Language.Granule.Codegen.Builtins.Extras Language.Granule.Codegen.Builtins.ImmutableArray + Language.Granule.Codegen.Builtins.MutableArray Language.Granule.Codegen.Builtins.Shared Language.Granule.Codegen.Emit.EmitableDef Language.Granule.Codegen.Emit.EmitBuiltins diff --git a/src/Language/Granule/Codegen/Builtins/MutableArray.hs b/src/Language/Granule/Codegen/Builtins/MutableArray.hs new file mode 100644 index 0000000..4f6726c --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/MutableArray.hs @@ -0,0 +1,115 @@ +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.MutableArray where + +import LLVM.AST +import LLVM.AST.Type as IR +import qualified LLVM.AST.Constant as C +import LLVM.IRBuilder.Constant as C +import LLVM.IRBuilder.Instruction +import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) +import Language.Granule.Codegen.Emit.Primitives (malloc, free) + +-- newFloatArray :: Int -> FloatArray id +newFloatArrayDef :: Builtin +newFloatArrayDef = + Builtin "newFloatArray" args ret impl + where + args = [tyInt] + ret = tyFloatArray + impl [len] = do + arrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] + arrPtr' <- bitcast arrPtr (ptr structTy) + -- length * double precision 8 bytes + dataSize <- mul len (int32 8) + dataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + + lenField <- gep arrPtr' [int32 0, int32 0] + store lenField 0 len + + dataField <- gep arrPtr' [int32 0, int32 1] + store dataField 0 dataPtr + + return arrPtr + +-- readFloatArray :: (FloatArray id) -> Int -> (Float, FloatArray id) +readFloatArrayDef :: Builtin +readFloatArrayDef = + Builtin "readFloatArray" args ret impl + where + args = [tyFloatArray, tyInt] + ret = tyPair (tyFloat, tyFloatArray) + impl [arrPtr, idx] = do + -- arr -> data -> idx -> val + + arrPtr' <- bitcast arrPtr (ptr structTy) + + dataField <- gep arrPtr' [int32 0, int32 1] + dataPtr <- load dataField 0 + + valuePtr <- gep dataPtr [idx] + value <- load valuePtr 0 + + let pairTy = StructureType False [IR.double, ptr structTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair value [0] + insertValue pair' arrPtr [1] + + + +-- writeFloatArray :: (FloatArray id) -> Int -> Float -> FloatArray id +writeFloatArrayDef :: Builtin +writeFloatArrayDef = + Builtin "writeFloatArray" args ret impl + where + args = [tyFloatArray, tyInt, tyFloat] + ret = tyFloatArray + impl [arrPtr, idx, val] = do + arrPtr' <- bitcast arrPtr (ptr structTy) + + dataField <- gep arrPtr' [int32 0, int32 1] + dataPtr <- load dataField 0 + + valuePtr <- gep arrPtr' [idx] + store valuePtr 0 val + + return arrPtr + +-- lengthFloatArray :: (FloatArray id) -> (Int -> FloatArray id) +lengthFloatArrayDef :: Builtin +lengthFloatArrayDef = + Builtin "lengthFloatArray" args ret impl + where + args = [tyFloatArray] + ret = tyPair (tyInt, tyFloatArray) + impl [arrPtr] = do + arrPtr' <- bitcast arrPtr (ptr structTy) + + lenField <- gep arrPtr' [int32 0, int32 0] + len <- load lenField 0 + + let pairTy = StructureType False [i32, ptr structTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair len [0] + insertValue pair' arrPtr [1] + +deleteFloatArrayDef :: Builtin +deleteFloatArrayDef = + Builtin "deleteFloatArray" args ret impl + where + args = [tyFloatArray] + ret = tyUnit + impl [arrPtr] = do + arrPtr' <- bitcast arrPtr (ptr structTy) + dataField <- gep arrPtr' [int32 0, int32 1] + dataPtr <- load dataField 0 + + -- free the data + _ <- call (ConstantOperand free) [(dataPtr, [])] + + -- free the array pair + _ <- call (ConstantOperand free) [(arrPtr, [])] + + -- return unit (need to check) + return $ ConstantOperand (C.Struct Nothing False []) diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs index 095dac1..72a3aa8 100644 --- a/src/Language/Granule/Codegen/Builtins/Shared.hs +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -34,6 +34,9 @@ tyFloat = TyCon (Id "Float" "Float") tyChar :: Gr.Type tyChar = TyCon (Id "Char" "Char") +tyUnit :: Gr.Type +tyUnit = TyCon (Id "()" "()") + tyPair :: (Gr.Type, Gr.Type) -> Gr.Type tyPair (l, r) = TyApp (TyApp (TyCon (Id "," ",")) l) r diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index c84a93b..546f0ae 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -47,6 +47,7 @@ emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = _ <- extern (mkName "abort") [] void _ <- externVarArgs (mkName "printf") [ptr i8] i32 _ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void + _ <- extern (mkName "free") [ptr i8] void _ <- emitBuiltins let mainTy = findMainReturnType valueDefs _ <- emitMainOut mainTy diff --git a/src/Language/Granule/Codegen/Emit/Primitives.hs b/src/Language/Granule/Codegen/Emit/Primitives.hs index 12e2aba..c5870ea 100644 --- a/src/Language/Granule/Codegen/Emit/Primitives.hs +++ b/src/Language/Granule/Codegen/Emit/Primitives.hs @@ -30,3 +30,8 @@ memcpy :: Constant memcpy = GlobalReference functionType name where name = mkName "llvm.memcpy.p0.p0.i32" functionType = ptr (FunctionType void [ptr i8, ptr i8, i32, i1] False) + +free :: Constant +free = GlobalReference functionType name + where name = mkName "free" + functionType = ptr (FunctionType void [ptr i8] False) From fa6259cd2a9ed39f489bb1a73db2cb078064f4d4 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 2 Mar 2025 23:46:20 +0000 Subject: [PATCH 08/25] lower unique types, expose builtins --- src/Language/Granule/Codegen/Builtins/Builtins.hs | 7 ++++++- src/Language/Granule/Codegen/Emit/LowerType.hs | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index 0fc0f3f..13eb691 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -2,11 +2,16 @@ module Language.Granule.Codegen.Builtins.Builtins where import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Builtins.ImmutableArray +import Language.Granule.Codegen.Builtins.MutableArray import Language.Granule.Codegen.Builtins.Extras import Language.Granule.Syntax.Identifiers (Id, mkId) builtins :: [Builtin] -builtins = [charToIntDef, divDef, newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef ] +builtins = [ + charToIntDef, divDef, + newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef, + newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef + ] builtinIds :: [Id] builtinIds = map (mkId . builtinId) builtins diff --git a/src/Language/Granule/Codegen/Emit/LowerType.hs b/src/Language/Granule/Codegen/Emit/LowerType.hs index d2c1ae3..b7d362d 100644 --- a/src/Language/Granule/Codegen/Emit/LowerType.hs +++ b/src/Language/Granule/Codegen/Emit/LowerType.hs @@ -52,6 +52,7 @@ llvmType (TyCon (MkId "Char")) = i8 llvmType (TyCon (MkId "Handle")) = i8 llvmType (TyCon (MkId "Bool")) = i1 llvmType (Box coeffect ty) = llvmType ty +llvmType (TyExists _ _ (Borrow _ ty)) = llvmType ty llvmType ty = error $ "Cannot lower the type " ++ show ty llvmTopLevelType :: GrType -> IrType From a832df6a66599976f0528fbda61c2c8444dd7513 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Wed, 5 Mar 2025 01:37:29 +0000 Subject: [PATCH 09/25] tests mutable arrays and Unpack / Box / Unit --- tests/golden/positive/boxes.golden | 1 + tests/golden/positive/boxes.gr | 10 ++++++++++ tests/golden/positive/unpack.golden | 1 + tests/golden/positive/unpack.gr | 9 +++++++++ 4 files changed, 21 insertions(+) create mode 100644 tests/golden/positive/boxes.golden create mode 100644 tests/golden/positive/boxes.gr create mode 100644 tests/golden/positive/unpack.golden create mode 100644 tests/golden/positive/unpack.gr diff --git a/tests/golden/positive/boxes.golden b/tests/golden/positive/boxes.golden new file mode 100644 index 0000000..10e0f08 --- /dev/null +++ b/tests/golden/positive/boxes.golden @@ -0,0 +1 @@ +42.000000 diff --git a/tests/golden/positive/boxes.gr b/tests/golden/positive/boxes.gr new file mode 100644 index 0000000..e1b6d9b --- /dev/null +++ b/tests/golden/positive/boxes.gr @@ -0,0 +1,10 @@ +deleteFloatArrayI : forall {id : Name} . (FloatArray id) [] -> () +deleteFloatArrayI [a] = () + +main : Float +main = + let [arr] = [(newFloatArrayI 1)] in + let [arr'] = [writeFloatArrayI arr 0 42.0] in + let [(x, arr'')] = [readFloatArrayI arr' 0] in + let () = deleteFloatArrayI [arr''] in + x diff --git a/tests/golden/positive/unpack.golden b/tests/golden/positive/unpack.golden new file mode 100644 index 0000000..bd55d74 --- /dev/null +++ b/tests/golden/positive/unpack.golden @@ -0,0 +1 @@ +(42.000000, 100.000000) diff --git a/tests/golden/positive/unpack.gr b/tests/golden/positive/unpack.gr new file mode 100644 index 0000000..59ffe55 --- /dev/null +++ b/tests/golden/positive/unpack.gr @@ -0,0 +1,9 @@ +main : (Float, Float) +main = + unpack = newFloatArray 1 in + let arr' = writeFloatArray arr 0 42.0 in + let arr'' = writeFloatArray arr' 1 100.0 in + let (x, arr''') = readFloatArray arr'' 0 in + let (y, arr'''') = readFloatArray arr''' 1 in + let () = deleteFloatArray arr'''' in + (x, y) From 7c40c12a944a1fea32ab28b6adf21d5f3fa4a2e6 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Wed, 5 Mar 2025 01:38:29 +0000 Subject: [PATCH 10/25] fix wrong pointer --- src/Language/Granule/Codegen/Builtins/MutableArray.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Granule/Codegen/Builtins/MutableArray.hs b/src/Language/Granule/Codegen/Builtins/MutableArray.hs index 4f6726c..3eb0ad8 100644 --- a/src/Language/Granule/Codegen/Builtins/MutableArray.hs +++ b/src/Language/Granule/Codegen/Builtins/MutableArray.hs @@ -71,7 +71,7 @@ writeFloatArrayDef = dataField <- gep arrPtr' [int32 0, int32 1] dataPtr <- load dataField 0 - valuePtr <- gep arrPtr' [idx] + valuePtr <- gep dataPtr [idx] store valuePtr 0 val return arrPtr From 20eceb0d60acd42309592cb795e05cf3dc9ae062 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Wed, 5 Mar 2025 01:39:22 +0000 Subject: [PATCH 11/25] types for uniqueness --- src/Language/Granule/Codegen/Emit/LowerType.hs | 3 ++- src/Language/Granule/Codegen/Emit/MainOut.hs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Granule/Codegen/Emit/LowerType.hs b/src/Language/Granule/Codegen/Emit/LowerType.hs index 85fe4cd..325097a 100644 --- a/src/Language/Granule/Codegen/Emit/LowerType.hs +++ b/src/Language/Granule/Codegen/Emit/LowerType.hs @@ -54,7 +54,8 @@ llvmType (TyCon (MkId "Char")) = i8 llvmType (TyCon (MkId "Handle")) = i8 llvmType (TyCon (MkId "Bool")) = i1 llvmType (Box coeffect ty) = llvmType ty -llvmType (TyExists _ _ (Borrow _ ty)) = llvmType ty +llvmType (TyExists _ _ ty) = llvmType ty +llvmType (Borrow (TyCon (MkId "Star")) ty) = llvmType ty llvmType ty = error $ "Cannot lower the type " ++ show ty llvmTopLevelType :: GrType -> IrType diff --git a/src/Language/Granule/Codegen/Emit/MainOut.hs b/src/Language/Granule/Codegen/Emit/MainOut.hs index 4a6ceb5..e79efb3 100644 --- a/src/Language/Granule/Codegen/Emit/MainOut.hs +++ b/src/Language/Granule/Codegen/Emit/MainOut.hs @@ -64,4 +64,5 @@ fmtStrForTy x = "(" ++ fmtStrForTy leftTy ++ ", " ++ fmtStrForTy rightTy ++ ")" (TyApp (TyCon (Id "FloatArray" _)) _) -> "" (TyCon (Id "()" _)) -> "()" - _ -> error "Unsupported" + (TyExists _ _ (Borrow _ ty)) -> "*" ++ fmtStrForTy ty + _ -> error ("Unsupported Main type: " ++ show x) From 9e5bce3e7a5b84723ab032c6d628156c2bdf4133 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Wed, 5 Mar 2025 01:39:35 +0000 Subject: [PATCH 12/25] AST processing --- granule-compiler.cabal | 4 + src/Language/Granule/Codegen/Compile.hs | 10 +- src/Language/Granule/Codegen/PrintAST.hs | 116 +++++++++++++++++++++ src/Language/Granule/Codegen/RetypeAST.hs | 108 +++++++++++++++++++ src/Language/Granule/Codegen/RewriteAST.hs | 45 ++++++++ src/Language/Granule/Codegen/StripAST.hs | 110 +++++++++++++++++++ 6 files changed, 389 insertions(+), 4 deletions(-) create mode 100644 src/Language/Granule/Codegen/PrintAST.hs create mode 100644 src/Language/Granule/Codegen/RetypeAST.hs create mode 100644 src/Language/Granule/Codegen/RewriteAST.hs create mode 100644 src/Language/Granule/Codegen/StripAST.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index b85192b..7a9e701 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -48,6 +48,10 @@ library Language.Granule.Codegen.Emit.Names Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types + Language.Granule.Codegen.PrintAST + Language.Granule.Codegen.RetypeAST + Language.Granule.Codegen.RewriteAST + Language.Granule.Codegen.StripAST Paths_granule_compiler hs-source-dirs: src diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index b747701..101a52d 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE ImplicitParams #-} {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} module Language.Granule.Codegen.Compile where @@ -9,13 +8,16 @@ import Language.Granule.Codegen.TopsortDefinitions import Language.Granule.Codegen.ConvertClosures import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals +import Language.Granule.Codegen.RewriteAST +import Language.Granule.Codegen.StripAST + import qualified LLVM.AST as IR ---import Language.Granule.Syntax.Pretty ---import Debug.Trace compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let normalised = {-trace (show typedAST)-} (normaliseDefinitions typedAST) + let stripped = stripAST typedAST + rewritten = rewriteAST stripped + normalised = normaliseDefinitions rewritten markedGlobals = markGlobals normalised (Ok topsorted) = topologicallySortDefinitions markedGlobals closureFree = convertClosures topsorted diff --git a/src/Language/Granule/Codegen/PrintAST.hs b/src/Language/Granule/Codegen/PrintAST.hs new file mode 100644 index 0000000..c1041f5 --- /dev/null +++ b/src/Language/Granule/Codegen/PrintAST.hs @@ -0,0 +1,116 @@ +module Language.Granule.Codegen.PrintAST where + +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers (Id (Id)) +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- converts an AST to JSON for debugging & visualisation. WIP + +data JSON = Str String | Obj [(String, JSON)] | Arr [JSON] + +printAST :: AST ev Type -> String +printAST ast = toString (jAST ast) + +toString :: JSON -> String +toString (Str s) = "\"" ++ s ++ "\"" +toString (Obj pairs) = "{" ++ commas (map field pairs) ++ "}" +toString (Arr elems) = "[" ++ commas (map toString elems) ++ "]" + +field :: (String, JSON) -> String +field (key, value) = "\"" ++ key ++ "\":" ++ toString value + +commas :: [String] -> String +commas [] = "" +commas [x] = x +commas (x : xs) = x ++ "," ++ commas xs + +node :: String -> [(String, JSON)] -> JSON +node name fields = Obj [(name, Obj fields)] + +jAST :: AST ev Type -> JSON +jAST (AST _ defs _ _ _) = Arr (map jDef defs) + +jDef :: Def ev Type -> JSON +jDef (Def _ _ _ _ equations _) = jEqList equations + +jEqList :: EquationList ev Type -> JSON +jEqList (EquationList _ _ _ equations) = Arr (map jEq equations) + +jEq :: Equation ev Type -> JSON +jEq (Equation s i a b ps expr) = jExpr expr + +jExpr :: Expr ev Type -> JSON +jExpr (App _ ty _ fn arg) = node "App" [("type", jTy ty), ("fn", jExpr fn), ("arg", jExpr arg)] +jExpr (Binop {}) = Str "%BINOP%" +jExpr (LetDiamond {}) = Str "%LETDIAMOND%" +jExpr (Val _ ty _ (Promote ty' expr)) = jExpr expr +jExpr (Val _ ty _ val) = node "Val" [("type", jTy ty), ("val", jVal val)] +jExpr (Case {}) = Str "%CASE%" +jExpr (Hole {}) = Str "%HOLE%" +jExpr (AppTy {}) = Str "%APP_TY%" +jExpr (TryCatch {}) = Str "%TRY_CATCH%" +jExpr (Unpack _ ty _ _ id e1 e2) = node "Unpack" [("type", jTy ty), ("id", jId id), ("e1", jExpr e1), ("e2", jExpr e2)] + +jVal :: Value ev Type -> JSON +jVal (Var ty id) = node "Var" [("type", jTy ty), ("id", jId id)] +jVal (Abs ty pat mty expr) = node "Abs" [("type", jTy ty), ("pattern", jPat pat), ("expr", jExpr expr)] +jVal (Promote ty expr) = node "Promote" [("type", jTy ty), ("expr", jExpr expr)] +jVal (Pure ty expr) = node "Pure" [("type", jTy ty), ("expr", jExpr expr)] +jVal (Constr ty id vals) = node "Constr" [("type", jTy ty), ("id", jId id), ("vals", Arr (map jVal vals))] +jVal (NumInt n) = Str (show n) +jVal (NumFloat n) = Str (show n) +jVal (CharLiteral c) = Str (show c) +jVal (StringLiteral s) = Str (show s) +jVal (Ext {}) = Str "%EXT%" +jVal (Nec {}) = Str "%NEC%" +jVal (Pack {}) = Str "%PACK%" +jVal (TyAbs {}) = Str "%TY_ABS%" + +jPat :: Pattern Type -> JSON +jPat (PVar _ ty _ id) = node "PVar" [("type", jTy ty), ("id", jId id)] +jPat (PWild {}) = Str "%PWILD%" +jPat (PBox _ ty _ p) = node "PBox" [("ty", jTy ty), ("pat", jPat p)] +jPat (PInt {}) = Str "%PINT%" +jPat (PFloat {}) = Str "%PFLOAT%" +jPat (PConstr _ ty _ id _ pats) = node "PConstr" [("type", jTy ty), ("id", jId id), ("pats", Arr (map jPat pats))] + +jId :: Id -> JSON +jId id = Str (sId id) + +jTy :: Type -> JSON +jTy ty = Str (sTy ty) + +sId :: Id -> String +sId (Id _ id) = id + +paren :: String -> String +paren str = "(" ++ str ++ ")" + +named :: String -> String -> String +named name str = name ++ " " ++ paren str + +sTy :: Type -> String +sTy (Type {}) = "type" +sTy (FunTy id _ arg ret) = sTy arg ++ " -> " ++ sTy ret +sTy (TyCon id) = sId id +sTy (Box _ ty) = "[" ++ sTy ty ++ "]" +sTy (Diamond {}) = "%DIAMOND%" +sTy (Star {}) = "%STAR%" +sTy (Borrow (TyCon (Id "Star" "Star")) ty) = "*" ++ paren (sTy ty) +sTy (Borrow p ty) = "& " ++ paren (sTy p) ++ " " ++ paren (sTy ty) +sTy (TyVar id) = "TyVar (" ++ sId id ++ ")" +sTy (TyApp t1 t2) = "(" ++ sTy t1 ++ ") (" ++ sTy t2 ++ ")" +sTy (TyInt {}) = "%TY_INT%" +sTy (TyRational {}) = "%TY_RATIONAL%" +sTy (TyFraction {}) = "%TY_FRACTION%" +sTy (TyGrade {}) = "%TY_GRADE%" +sTy (TyInfix {}) = "%TY_INFIX%" +sTy (TySet {}) = "%TY_SET%" +sTy (TyCase {}) = "%TY_CASE%" +sTy (TySig {}) = "%TY_SIG%" +-- sTy (TyExists id kind ty) = "exists {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty +sTy (TyExists id kind ty) = sTy ty +sTy (TyForall id kind ty) = "forall {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty +sTy (TyName {}) = "%TY_NAME%" diff --git a/src/Language/Granule/Codegen/RetypeAST.hs b/src/Language/Granule/Codegen/RetypeAST.hs new file mode 100644 index 0000000..4f0b6f8 --- /dev/null +++ b/src/Language/Granule/Codegen/RetypeAST.hs @@ -0,0 +1,108 @@ +module Language.Granule.Codegen.RetypeAST where + +import Data.Bifunctor (bimap) +import Data.List (mapAccumL) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers (Id) +import Language.Granule.Syntax.Pattern (Pattern (..)) +import Language.Granule.Syntax.Type + +-- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these +-- are not already handled by the compiler. Here we find the correct types and substitute +-- the TyVars. WIP. + +-- val var -> Type, type var -> Type +type Env = (Map.Map Id Type, Map.Map Id Type) + +emptyEnv :: Env +emptyEnv = (Map.empty, Map.empty) + +insertEnv :: Env -> Either Id Id -> Type -> Env +insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) +insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) + +lookupEnv :: Env -> Either Id Id -> Maybe Type +lookupEnv (vals, tys) (Left id) = Map.lookup id vals +lookupEnv (vals, tys) (Right id) = Map.lookup id tys + +retypeAST :: AST ev Type -> AST ev Type +retypeAST ast = ast {definitions = map retypeDef (definitions ast)} + where + retypeDef def = def {defEquations = retypeEquationList (defEquations def)} + retypeEquationList eqs = eqs {equations = map retypeEquation (equations eqs)} + retypeEquation eq = eq {equationBody = snd (retypeExpr emptyEnv (equationBody eq))} + +retypeExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) +retypeExpr env (App s ty b e1 e2) = + let (env', e2') = retypeExpr env e2 + (env'', e1') = retypeExpr env' e1 + ty' = subsTy env ty + in (env'', App s ty' b e1' e2') +retypeExpr env (Val s ty b v) = + let (env', v') = retypeVal env v + ty' = subsTy env' ty + in (env', Val s ty' b v') +retypeExpr env exp = error "TODO expr" + +retypeVal :: Env -> Value ev Type -> (Env, Value ev Type) +retypeVal env (Var (TyVar id) var) = + -- see if we already have it + case lookupEnv env (Right id) of + Just ty -> (env, Var ty var) + Nothing -> + -- see if the value variable has it + case lookupEnv env (Left var) of + -- and update + Just ty -> (insertEnv env (Right id) ty, Var ty var) + -- we wont always win + Nothing -> (env, Var (TyVar id) var) +retypeVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) +retypeVal env (Abs ty p mt e) = + let (env', p') = retypePat env p + (env'', e') = retypeExpr env' e + ty' = subsTy env'' ty + in (env'', Abs ty' p' mt e') +retypeVal env (Constr ty id vals) = + let (env', vals') = mapAccumL retypeVal env vals + ty' = subsTy env' ty + in (env', Constr ty' id vals') +retypeVal env (NumInt v) = (env, NumInt v) +retypeVal env (NumFloat v) = (env, NumFloat v) +retypeVal env val = error "TODO val" + +retypePat :: Env -> Pattern Type -> (Env, Pattern Type) +retypePat env (PVar s (TyVar id) b var) = + case lookupEnv env (Right id) of + Just ty -> (env, PVar s ty b var) + Nothing -> + case lookupEnv env (Left var) of + Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) + Nothing -> (env, PVar s (TyVar id) b var) +retypePat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) +retypePat env (PConstr s ty b id ids ps) = + let (env', ps') = mapAccumL retypePat env ps + ty' = subsTy env' ty + in (env', PConstr s ty' b id ids ps') +retypePat env p = error "TODO pat" + +subsTy :: Env -> Type -> Type +subsTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) +subsTy env (Type i) = Type i +subsTy env (FunTy id mc arg ret) = FunTy id mc (subsTy env arg) (subsTy env ret) +subsTy env (TyCon id) = TyCon id +subsTy env (Box c t) = subsTy env t +subsTy env (Diamond e t) = Diamond (subsTy env e) (subsTy env t) +subsTy env (Star g t) = subsTy env t +subsTy env (Borrow p t) = subsTy env t +subsTy env (TyApp t1 t2) = TyApp (subsTy env t1) (subsTy env t2) +subsTy env (TyGrade mt i) = TyGrade mt i +subsTy env (TyInfix op t1 t2) = TyInfix op (subsTy env t1) (subsTy env t2) +subsTy env (TySet p ts) = TySet p (map (subsTy env) ts) +subsTy env (TyCase t tps) = TyCase (subsTy env t) (map (bimap (subsTy env) (subsTy env)) tps) +subsTy env (TySig t k) = TySig (subsTy env t) (subsTy env k) +subsTy env (TyExists id k t) = subsTy env t +subsTy env (TyForall id k t) = subsTy env t +subsTy env t = t diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs new file mode 100644 index 0000000..8927a7a --- /dev/null +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -0,0 +1,45 @@ +module Language.Granule.Codegen.RewriteAST where + +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type +import Language.Granule.Codegen.RetypeAST + +-- Rewrite Unpack ASTs into App Abs ASTs which our +-- compiler already knows how to handle. WIP. + +rewriteAST :: AST ev Type -> AST ev Type +rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} + where + rewriteDef def = def {defEquations = rewriteEquationList (defEquations def)} + rewriteEquationList eqs = eqs {equations = map rewriteEquation (equations eqs)} + rewriteEquation eq = eq {equationBody = rewriteExpr (equationBody eq)} + +rewriteExpr :: Expr ev Type -> Expr ev Type +rewriteExpr (Unpack s retTy b tyVar var e1 e2) = + let e1' = rewriteExpr e1 + e1Ty = exprTy e1' + e2' = rewriteExpr e2 + absTy = FunTy Nothing Nothing e1Ty retTy + in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') + where + fixTypes expr = snd $ retypeExpr emptyEnv expr +rewriteExpr (App s a b e1 e2) = App s a b (rewriteExpr e1) (rewriteExpr e2) +rewriteExpr (Val s a b v) = Val s a b (rewriteVal v) +rewriteExpr exp = exp + +rewriteVal :: Value ev Type -> Value ev Type +rewriteVal (Abs a p mt e) = Abs a p mt (rewriteExpr e) +rewriteVal val = val + +exprTy :: Expr ev Type -> Type +exprTy (App _ ty _ _ _) = ty +exprTy (Val _ ty _ _) = ty +exprTy (Binop _ ty _ _ _ _) = ty +exprTy (LetDiamond _ ty _ _ _ _ _) = ty +exprTy (Case _ ty _ _ _) = ty +exprTy (Hole _ ty _ _ _) = ty +exprTy (AppTy _ ty _ _ _) = ty +exprTy (TryCatch _ ty _ _ _ _ _ _) = ty +exprTy (Unpack _ ty _ _ _ _ _) = ty diff --git a/src/Language/Granule/Codegen/StripAST.hs b/src/Language/Granule/Codegen/StripAST.hs new file mode 100644 index 0000000..b4c991d --- /dev/null +++ b/src/Language/Granule/Codegen/StripAST.hs @@ -0,0 +1,110 @@ +module Language.Granule.Codegen.StripAST where + +import Data.Bifunctor (Bifunctor (second), bimap) +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- Strips types which are not currently needed (or handled) by +-- the compiler, to make life easier and debugging simpler. We are +-- stripping Box, Star, Borrow and type quantifiers, but we may +-- wish to reinstate these to help with future optimisation. WIP. + +stripAST :: AST ev Type -> AST ev Type +stripAST (AST decls defs imports hidden name) = + AST decls (map stripDef defs) imports hidden name + +stripDef :: Def ev Type -> Def ev Type +stripDef (Def s i b spec el ts) = + Def s i b spec (stripEquationList el) (stripTypeScheme ts) + +stripEquationList :: EquationList ev Type -> EquationList ev Type +stripEquationList (EquationList s v b es) = + EquationList s v b (map stripEquation es) + +stripEquation :: Equation ev Type -> Equation ev Type +stripEquation (Equation s n a b ps e) = + Equation s n (stripTy a) b (map stripPat ps) (stripExpr e) + +stripExpr :: Expr ev Type -> Expr ev Type +stripExpr (App s a b e1 e2) = + App s (stripTy a) b (stripExpr e1) (stripExpr e2) +stripExpr (Binop s a b op e1 e2) = + Binop s (stripTy a) b op (stripExpr e1) (stripExpr e2) +stripExpr (LetDiamond s a b p mt e1 e2) = + LetDiamond s (stripTy a) b (stripPat p) (stripMaybeTy mt) (stripExpr e1) (stripExpr e2) +stripExpr (Val s a b v) = + Val s (stripTy a) b (stripVal v) +stripExpr (Case s a b e pes) = + Case s (stripTy a) b (stripExpr e) (map (bimap stripPat stripExpr) pes) +stripExpr (Hole s a b ids hints) = + Hole s (stripTy a) b ids hints +stripExpr (AppTy s a b e t) = + AppTy s (stripTy a) b (stripExpr e) (stripTy t) +stripExpr (TryCatch s a b e1 p mt e2 e3) = + TryCatch s (stripTy a) b (stripExpr e1) (stripPat p) (stripMaybeTy mt) (stripExpr e2) (stripExpr e3) +stripExpr (Unpack s a b tyVar var e1 e2) = + Unpack s (stripTy a) b tyVar var (stripExpr e1) (stripExpr e2) + +stripVal :: Value ev Type -> Value ev Type +stripVal (Var a id) = + Var (stripTy a) id +stripVal (Abs a p mt e) = + Abs (stripTy a) (stripPat p) (stripMaybeTy mt) (stripExpr e) +stripVal (Promote a e) = + Promote (stripTy a) (stripExpr e) +stripVal (Pure a e) = + Pure (stripTy a) (stripExpr e) +stripVal (Constr a id vs) = + Constr (stripTy a) id (map stripVal vs) +stripVal (Ext a ev) = + Ext (stripTy a) ev +stripVal (Nec a e) = + Nec (stripTy a) (stripExpr e) +stripVal (Pack s a t e id k t') = + Pack s (stripTy a) (stripTy t) (stripExpr e) id (stripTy k) (stripTy t') +stripVal (TyAbs a (Left (id, t)) e) = + TyAbs (stripTy a) (Left (id, stripTy t)) (stripExpr e) +stripVal (TyAbs a (Right ids) e) = + TyAbs (stripTy a) (Right ids) (stripExpr e) +stripVal v = v + +stripPat :: Pattern Type -> Pattern Type +stripPat (PVar s a b v) = PVar s (stripTy a) b v +stripPat (PWild s a b) = PWild s (stripTy a) b +stripPat (PBox s a b p) = stripPat p +stripPat (PInt s a b i) = PInt s (stripTy a) b i +stripPat (PFloat s a b f) = PFloat s (stripTy a) b f +stripPat (PConstr s a b id ids ps) = PConstr s (stripTy a) b id ids (map stripPat ps) + +stripTypeScheme :: TypeScheme -> TypeScheme +stripTypeScheme (Forall s quants constraints t) = + Forall + s + (map (second stripTy) quants) + (map stripTy constraints) + (stripTy t) + +stripMaybeTy :: Maybe Type -> Maybe Type +stripMaybeTy Nothing = Nothing +stripMaybeTy (Just ty) = Just (stripTy ty) + +stripTy :: Type -> Type +stripTy (Type i) = Type i +stripTy (FunTy id mc arg ret) = FunTy id (stripMaybeTy mc) (stripTy arg) (stripTy ret) +stripTy (TyCon id) = TyCon id +stripTy (Box c t) = stripTy t +stripTy (Diamond e t) = Diamond (stripTy e) (stripTy t) +stripTy (Star g t) = stripTy t +stripTy (Borrow p t) = stripTy t +stripTy (TyVar id) = TyVar id +stripTy (TyApp t1 t2) = TyApp (stripTy t1) (stripTy t2) +stripTy (TyGrade mt i) = TyGrade (stripMaybeTy mt) i +stripTy (TyInfix op t1 t2) = TyInfix op (stripTy t1) (stripTy t2) +stripTy (TySet p ts) = TySet p (map stripTy ts) +stripTy (TyCase t tps) = TyCase (stripTy t) (map (bimap stripTy stripTy) tps) +stripTy (TySig t k) = TySig (stripTy t) (stripTy k) +stripTy (TyExists id k t) = stripTy t +stripTy (TyForall id k t) = stripTy t +stripTy t = t From daf22ea60231eac4f578a3f161829ceeb3b0dd8f Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 6 Mar 2025 09:55:28 +0000 Subject: [PATCH 13/25] less incorrect (was having type errors on M1) --- .../Codegen/Builtins/ImmutableArray.hs | 19 +++++++++++-------- .../Granule/Codegen/Builtins/MutableArray.hs | 13 ++++++++----- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs index 9d177ea..f2ce3ff 100644 --- a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs +++ b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs @@ -23,13 +23,15 @@ newFloatArrayIDef = arrPtr' <- bitcast arrPtr (ptr structTy) -- length * double precision 8 bytes dataSize <- mul len (int32 8) - dataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + dataSize64 <- sext dataSize i64 + dataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] + dataPtr' <- bitcast dataPtr (ptr IR.double) lenField <- gep arrPtr' [int32 0, int32 0] store lenField 0 len dataField <- gep arrPtr' [int32 0, int32 1] - store dataField 0 dataPtr + store dataField 0 dataPtr' return arrPtr @@ -79,13 +81,14 @@ writeFloatArrayIDef = newArrPtr' <- bitcast newArrPtr (ptr structTy) dataSize <- mul len (int32 8) - newDataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + dataSize64 <- sext dataSize i64 + newDataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] newDataPtr' <- bitcast newDataPtr (ptr IR.double) - - -- copy the existing data to new array + dataPtr' <- bitcast dataPtr (ptr i8) + newDataPtr'' <- bitcast newDataPtr' (ptr i8) _ <- call (ConstantOperand memcpy) - [ (newDataPtr, []) - , (dataPtr, []) + [ (newDataPtr'', []) + , (dataPtr', []) , (dataSize, []) , (bit 0, []) ] @@ -98,7 +101,7 @@ writeFloatArrayIDef = store newLenField 0 len newDataField <- gep newArrPtr' [int32 0, int32 1] - store newDataField 0 newDataPtr + store newDataField 0 newDataPtr' return newArrPtr diff --git a/src/Language/Granule/Codegen/Builtins/MutableArray.hs b/src/Language/Granule/Codegen/Builtins/MutableArray.hs index 3eb0ad8..0ed8205 100644 --- a/src/Language/Granule/Codegen/Builtins/MutableArray.hs +++ b/src/Language/Granule/Codegen/Builtins/MutableArray.hs @@ -23,13 +23,15 @@ newFloatArrayDef = arrPtr' <- bitcast arrPtr (ptr structTy) -- length * double precision 8 bytes dataSize <- mul len (int32 8) - dataPtr <- call (ConstantOperand malloc) [(dataSize, [])] + dataSize64 <- sext dataSize i64 + dataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] + dataPtr' <- bitcast dataPtr (ptr IR.double) lenField <- gep arrPtr' [int32 0, int32 0] store lenField 0 len dataField <- gep arrPtr' [int32 0, int32 1] - store dataField 0 dataPtr + store dataField 0 dataPtr' return arrPtr @@ -104,12 +106,13 @@ deleteFloatArrayDef = arrPtr' <- bitcast arrPtr (ptr structTy) dataField <- gep arrPtr' [int32 0, int32 1] dataPtr <- load dataField 0 + dataPtr' <- bitcast dataPtr (ptr i8) -- free the data - _ <- call (ConstantOperand free) [(dataPtr, [])] + _ <- call (ConstantOperand free) [(dataPtr', [])] - -- free the array pair - _ <- call (ConstantOperand free) [(arrPtr, [])] + arrPtr'' <- bitcast arrPtr (ptr i8) + _ <- call (ConstantOperand free) [(arrPtr'', [])] -- return unit (need to check) return $ ConstantOperand (C.Struct Nothing False []) From 6e7e3b33addf426e85d0dcc73d5483565ac93d50 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 6 Mar 2025 12:25:25 +0000 Subject: [PATCH 14/25] handle more ast --- src/Language/Granule/Codegen/PrintAST.hs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/Language/Granule/Codegen/PrintAST.hs b/src/Language/Granule/Codegen/PrintAST.hs index c1041f5..83180ca 100644 --- a/src/Language/Granule/Codegen/PrintAST.hs +++ b/src/Language/Granule/Codegen/PrintAST.hs @@ -5,10 +5,11 @@ import Language.Granule.Syntax.Expr import Language.Granule.Syntax.Identifiers (Id (Id)) import Language.Granule.Syntax.Pattern import Language.Granule.Syntax.Type +import Data.Bifunctor (bimap) -- converts an AST to JSON for debugging & visualisation. WIP -data JSON = Str String | Obj [(String, JSON)] | Arr [JSON] +data JSON = Str String | Obj [(String, JSON)] | Arr [JSON] | Pair (JSON, JSON) printAST :: AST ev Type -> String printAST ast = toString (jAST ast) @@ -17,6 +18,7 @@ toString :: JSON -> String toString (Str s) = "\"" ++ s ++ "\"" toString (Obj pairs) = "{" ++ commas (map field pairs) ++ "}" toString (Arr elems) = "[" ++ commas (map toString elems) ++ "]" +toString (Pair (left, right)) = toString (Arr [left, right]) field :: (String, JSON) -> String field (key, value) = "\"" ++ key ++ "\":" ++ toString value @@ -43,16 +45,16 @@ jEq (Equation s i a b ps expr) = jExpr expr jExpr :: Expr ev Type -> JSON jExpr (App _ ty _ fn arg) = node "App" [("type", jTy ty), ("fn", jExpr fn), ("arg", jExpr arg)] -jExpr (Binop {}) = Str "%BINOP%" +jExpr (Binop _ ty _ op e1 e2) = node "Binop" [("type", jTy ty), ("op", Str (show op)), ("e1", jExpr e1), ("e2", jExpr e2)] jExpr (LetDiamond {}) = Str "%LETDIAMOND%" -jExpr (Val _ ty _ (Promote ty' expr)) = jExpr expr jExpr (Val _ ty _ val) = node "Val" [("type", jTy ty), ("val", jVal val)] -jExpr (Case {}) = Str "%CASE%" +jExpr (Case s ty b e ps) = node "Case" [("type", jTy ty), ("expr", jExpr e), ("cases", Arr (map (Pair . bimap jPat jExpr) ps))] jExpr (Hole {}) = Str "%HOLE%" jExpr (AppTy {}) = Str "%APP_TY%" jExpr (TryCatch {}) = Str "%TRY_CATCH%" jExpr (Unpack _ ty _ _ id e1 e2) = node "Unpack" [("type", jTy ty), ("id", jId id), ("e1", jExpr e1), ("e2", jExpr e2)] + jVal :: Value ev Type -> JSON jVal (Var ty id) = node "Var" [("type", jTy ty), ("id", jId id)] jVal (Abs ty pat mty expr) = node "Abs" [("type", jTy ty), ("pattern", jPat pat), ("expr", jExpr expr)] @@ -70,10 +72,10 @@ jVal (TyAbs {}) = Str "%TY_ABS%" jPat :: Pattern Type -> JSON jPat (PVar _ ty _ id) = node "PVar" [("type", jTy ty), ("id", jId id)] -jPat (PWild {}) = Str "%PWILD%" +jPat (PWild s a b) = Str "_" jPat (PBox _ ty _ p) = node "PBox" [("ty", jTy ty), ("pat", jPat p)] -jPat (PInt {}) = Str "%PINT%" -jPat (PFloat {}) = Str "%PFLOAT%" +jPat (PInt s a b v) = Str (show v) +jPat (PFloat s a b v) = Str (show v) jPat (PConstr _ ty _ id _ pats) = node "PConstr" [("type", jTy ty), ("id", jId id), ("pats", Arr (map jPat pats))] jId :: Id -> JSON @@ -110,7 +112,6 @@ sTy (TyInfix {}) = "%TY_INFIX%" sTy (TySet {}) = "%TY_SET%" sTy (TyCase {}) = "%TY_CASE%" sTy (TySig {}) = "%TY_SIG%" --- sTy (TyExists id kind ty) = "exists {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty -sTy (TyExists id kind ty) = sTy ty +sTy (TyExists id kind ty) = "exists {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty sTy (TyForall id kind ty) = "forall {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty sTy (TyName {}) = "%TY_NAME%" From 5acd943613b87edbb5e3e0a55f5884d05ffb4a71 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 6 Mar 2025 12:26:14 +0000 Subject: [PATCH 15/25] more ast --- src/Language/Granule/Codegen/RetypeAST.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Language/Granule/Codegen/RetypeAST.hs b/src/Language/Granule/Codegen/RetypeAST.hs index 4f0b6f8..7933b25 100644 --- a/src/Language/Granule/Codegen/RetypeAST.hs +++ b/src/Language/Granule/Codegen/RetypeAST.hs @@ -71,6 +71,7 @@ retypeVal env (Constr ty id vals) = in (env', Constr ty' id vals') retypeVal env (NumInt v) = (env, NumInt v) retypeVal env (NumFloat v) = (env, NumFloat v) +retypeVal env (Promote t v) = (env, Promote t v) retypeVal env val = error "TODO val" retypePat :: Env -> Pattern Type -> (Env, Pattern Type) From c30594d246a0a1a9aa657068852a03ffad224c2f Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 6 Mar 2025 12:26:41 +0000 Subject: [PATCH 16/25] we only need to rewrite unpack then fix types --- src/Language/Granule/Codegen/RewriteAST.hs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs index 8927a7a..f3f52fe 100644 --- a/src/Language/Granule/Codegen/RewriteAST.hs +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -18,20 +18,15 @@ rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} rewriteExpr :: Expr ev Type -> Expr ev Type rewriteExpr (Unpack s retTy b tyVar var e1 e2) = - let e1' = rewriteExpr e1 + let e1' = e1 e1Ty = exprTy e1' - e2' = rewriteExpr e2 + e2' = e2 absTy = FunTy Nothing Nothing e1Ty retTy in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') where fixTypes expr = snd $ retypeExpr emptyEnv expr -rewriteExpr (App s a b e1 e2) = App s a b (rewriteExpr e1) (rewriteExpr e2) -rewriteExpr (Val s a b v) = Val s a b (rewriteVal v) rewriteExpr exp = exp -rewriteVal :: Value ev Type -> Value ev Type -rewriteVal (Abs a p mt e) = Abs a p mt (rewriteExpr e) -rewriteVal val = val exprTy :: Expr ev Type -> Type exprTy (App _ ty _ _ _) = ty From 07dfd24405e5c04426358c79769a34feaca27a08 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 00:39:24 +0000 Subject: [PATCH 17/25] shared helpers for arrays, can be used by non-float arrays (strings) --- .../Granule/Codegen/Builtins/Builtins.hs | 5 +- .../Codegen/Builtins/ImmutableArray.hs | 82 +++---------------- .../Granule/Codegen/Builtins/MutableArray.hs | 73 ++++------------- .../Granule/Codegen/Builtins/Shared.hs | 75 +++++++++++++++-- 4 files changed, 102 insertions(+), 133 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index 13eb691..00f5530 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -9,7 +9,10 @@ import Language.Granule.Syntax.Identifiers (Id, mkId) builtins :: [Builtin] builtins = [ charToIntDef, divDef, - newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef, + newFloatArrayIDef, + readFloatArrayIDef, + writeFloatArrayIDef, + lengthFloatArrayIDef, newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef ] diff --git a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs index f2ce3ff..d85cbeb 100644 --- a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs +++ b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs @@ -2,14 +2,10 @@ module Language.Granule.Codegen.Builtins.ImmutableArray where -import LLVM.AST import LLVM.AST.Type as IR -import qualified LLVM.AST.Constant as C import LLVM.IRBuilder.Constant as C import LLVM.IRBuilder.Instruction import Language.Granule.Codegen.Builtins.Shared -import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) -import Language.Granule.Codegen.Emit.Primitives (malloc, memcpy) -- newFloatArrayI :: Int -> FloatArray id newFloatArrayIDef :: Builtin @@ -19,21 +15,11 @@ newFloatArrayIDef = args = [tyInt] ret = tyFloatArray impl [len] = do - arrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] - arrPtr' <- bitcast arrPtr (ptr structTy) -- length * double precision 8 bytes dataSize <- mul len (int32 8) dataSize64 <- sext dataSize i64 - dataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] - dataPtr' <- bitcast dataPtr (ptr IR.double) - - lenField <- gep arrPtr' [int32 0, int32 0] - store lenField 0 len - - dataField <- gep arrPtr' [int32 0, int32 1] - store dataField 0 dataPtr' - - return arrPtr + dataPtr <- allocate dataSize64 IR.double + makeArray IR.double len dataPtr -- readFloatArrayI :: (FloatArray id) -> Int -> (Float, FloatArray id) readFloatArrayIDef :: Builtin @@ -43,22 +29,10 @@ readFloatArrayIDef = args = [tyFloatArray, tyInt] ret = tyPair (tyFloat, tyFloatArray) impl [arrPtr, idx] = do - -- arr -> data -> idx -> val - - arrPtr' <- bitcast arrPtr (ptr structTy) - - dataField <- gep arrPtr' [int32 0, int32 1] - dataPtr <- load dataField 0 - + dataPtr <- getArrayDataPtr arrPtr valuePtr <- gep dataPtr [idx] value <- load valuePtr 0 - - let pairTy = StructureType False [IR.double, ptr structTy] - let pair = ConstantOperand $ C.Undef pairTy - pair' <- insertValue pair value [0] - insertValue pair' arrPtr [1] - - + makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) -- writeFloatArrayI :: (FloatArray id) -> Int -> Float -> FloatArray id writeFloatArrayIDef :: Builtin @@ -68,42 +42,19 @@ writeFloatArrayIDef = args = [tyFloatArray, tyInt, tyFloat] ret = tyFloatArray impl [arrPtr, idx, val] = do - arrPtr' <- bitcast arrPtr (ptr structTy) - - lenField <- gep arrPtr' [int32 0, int32 0] - len <- load lenField 0 - - dataField <- gep arrPtr' [int32 0, int32 1] - dataPtr <- load dataField 0 - - -- need to create a new array as in newFloatArrayI - newArrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] - newArrPtr' <- bitcast newArrPtr (ptr structTy) + len <- getArrayLen arrPtr + dataPtr <- getArrayDataPtr arrPtr dataSize <- mul len (int32 8) dataSize64 <- sext dataSize i64 - newDataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] - newDataPtr' <- bitcast newDataPtr (ptr IR.double) - dataPtr' <- bitcast dataPtr (ptr i8) - newDataPtr'' <- bitcast newDataPtr' (ptr i8) - _ <- call (ConstantOperand memcpy) - [ (newDataPtr'', []) - , (dataPtr', []) - , (dataSize, []) - , (bit 0, []) - ] - - -- write the val to new copy - valuePtr <- gep newDataPtr' [idx] - store valuePtr 0 val + newDataPtr <- allocateFloatArray len - newLenField <- gep newArrPtr' [int32 0, int32 0] - store newLenField 0 len + _ <- copy newDataPtr dataPtr dataSize - newDataField <- gep newArrPtr' [int32 0, int32 1] - store newDataField 0 newDataPtr' + valuePtr <- gep newDataPtr [idx] + store valuePtr 0 val - return newArrPtr + makeArray IR.double len newDataPtr -- lengthFloatArrayI :: (FloatArray id) -> (Int -> FloatArray id) lengthFloatArrayIDef :: Builtin @@ -113,12 +64,5 @@ lengthFloatArrayIDef = args = [tyFloatArray] ret = tyPair (tyInt, tyFloatArray) impl [arrPtr] = do - arrPtr' <- bitcast arrPtr (ptr structTy) - - lenField <- gep arrPtr' [int32 0, int32 0] - len <- load lenField 0 - - let pairTy = StructureType False [i32, ptr structTy] - let pair = ConstantOperand $ C.Undef pairTy - pair' <- insertValue pair len [0] - insertValue pair' arrPtr [1] + len <- getArrayLen arrPtr + makePair (i32, len) (ptr floatArrayStruct, arrPtr) diff --git a/src/Language/Granule/Codegen/Builtins/MutableArray.hs b/src/Language/Granule/Codegen/Builtins/MutableArray.hs index 0ed8205..1c433bf 100644 --- a/src/Language/Granule/Codegen/Builtins/MutableArray.hs +++ b/src/Language/Granule/Codegen/Builtins/MutableArray.hs @@ -8,8 +8,6 @@ import qualified LLVM.AST.Constant as C import LLVM.IRBuilder.Constant as C import LLVM.IRBuilder.Instruction import Language.Granule.Codegen.Builtins.Shared -import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) -import Language.Granule.Codegen.Emit.Primitives (malloc, free) -- newFloatArray :: Int -> FloatArray id newFloatArrayDef :: Builtin @@ -19,21 +17,11 @@ newFloatArrayDef = args = [tyInt] ret = tyFloatArray impl [len] = do - arrPtr <- call (ConstantOperand malloc) [(ConstantOperand $ sizeOf structTy, [])] - arrPtr' <- bitcast arrPtr (ptr structTy) - -- length * double precision 8 bytes - dataSize <- mul len (int32 8) - dataSize64 <- sext dataSize i64 - dataPtr <- call (ConstantOperand malloc) [(dataSize64, [])] - dataPtr' <- bitcast dataPtr (ptr IR.double) - - lenField <- gep arrPtr' [int32 0, int32 0] - store lenField 0 len - - dataField <- gep arrPtr' [int32 0, int32 1] - store dataField 0 dataPtr' - - return arrPtr + -- length * double precision 8 bytes + dataSize <- mul len (int32 8) + dataSize64 <- sext dataSize i64 + dataPtr <- allocate dataSize64 IR.double + makeArray IR.double len dataPtr -- readFloatArray :: (FloatArray id) -> Int -> (Float, FloatArray id) readFloatArrayDef :: Builtin @@ -43,22 +31,10 @@ readFloatArrayDef = args = [tyFloatArray, tyInt] ret = tyPair (tyFloat, tyFloatArray) impl [arrPtr, idx] = do - -- arr -> data -> idx -> val - - arrPtr' <- bitcast arrPtr (ptr structTy) - - dataField <- gep arrPtr' [int32 0, int32 1] - dataPtr <- load dataField 0 - - valuePtr <- gep dataPtr [idx] - value <- load valuePtr 0 - - let pairTy = StructureType False [IR.double, ptr structTy] - let pair = ConstantOperand $ C.Undef pairTy - pair' <- insertValue pair value [0] - insertValue pair' arrPtr [1] - - + dataPtr <- getArrayDataPtr arrPtr + valuePtr <- gep dataPtr [idx] + value <- load valuePtr 0 + makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) -- writeFloatArray :: (FloatArray id) -> Int -> Float -> FloatArray id writeFloatArrayDef :: Builtin @@ -68,14 +44,9 @@ writeFloatArrayDef = args = [tyFloatArray, tyInt, tyFloat] ret = tyFloatArray impl [arrPtr, idx, val] = do - arrPtr' <- bitcast arrPtr (ptr structTy) - - dataField <- gep arrPtr' [int32 0, int32 1] - dataPtr <- load dataField 0 - + dataPtr <- getArrayDataPtr arrPtr valuePtr <- gep dataPtr [idx] store valuePtr 0 val - return arrPtr -- lengthFloatArray :: (FloatArray id) -> (Int -> FloatArray id) @@ -86,15 +57,8 @@ lengthFloatArrayDef = args = [tyFloatArray] ret = tyPair (tyInt, tyFloatArray) impl [arrPtr] = do - arrPtr' <- bitcast arrPtr (ptr structTy) - - lenField <- gep arrPtr' [int32 0, int32 0] - len <- load lenField 0 - - let pairTy = StructureType False [i32, ptr structTy] - let pair = ConstantOperand $ C.Undef pairTy - pair' <- insertValue pair len [0] - insertValue pair' arrPtr [1] + len <- getArrayLen arrPtr + makePair (i32, len) (ptr floatArrayStruct, arrPtr) deleteFloatArrayDef :: Builtin deleteFloatArrayDef = @@ -103,16 +67,9 @@ deleteFloatArrayDef = args = [tyFloatArray] ret = tyUnit impl [arrPtr] = do - arrPtr' <- bitcast arrPtr (ptr structTy) - dataField <- gep arrPtr' [int32 0, int32 1] - dataPtr <- load dataField 0 - dataPtr' <- bitcast dataPtr (ptr i8) - - -- free the data - _ <- call (ConstantOperand free) [(dataPtr', [])] - - arrPtr'' <- bitcast arrPtr (ptr i8) - _ <- call (ConstantOperand free) [(arrPtr'', [])] + dataPtr <- getArrayDataPtr arrPtr + _ <- free dataPtr + _ <- free arrPtr -- return unit (need to check) return $ ConstantOperand (C.Struct Nothing False []) diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs index 72a3aa8..2836c58 100644 --- a/src/Language/Granule/Codegen/Builtins/Shared.hs +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -5,11 +5,15 @@ module Language.Granule.Codegen.Builtins.Shared where import LLVM.AST import LLVM.AST.Type as IR +import LLVM.IRBuilder.Constant (bit, int32) +import LLVM.IRBuilder.Instruction (bitcast, call, gep, load, store, insertValue, mul, sext) import LLVM.IRBuilder.Module import LLVM.IRBuilder.Monad - +import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) +import Language.Granule.Codegen.Emit.Primitives as P import Language.Granule.Syntax.Identifiers import Language.Granule.Syntax.Type as Gr +import qualified LLVM.AST.Constant as C data Builtin = Builtin { builtinId :: String, @@ -17,14 +21,75 @@ data Builtin = Builtin { builtinRetTy :: Gr.Type, builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} --- helpers +-- LLVM helpers + +allocate :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> IR.Type -> m Operand +allocate len ty = do + pointer <- call (ConstantOperand P.malloc) [(len, [])] + bitcast pointer (ptr ty) + +allocateStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> m Operand +allocateStruct ty = allocate (ConstantOperand $ sizeOf ty) ty + +allocateFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +allocateFloatArray len = do + dataSize <- mul len (int32 8) + dataSize64 <- sext dataSize i64 + allocate dataSize64 IR.double + +copy :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> Operand -> m Operand +copy dst src len = do + dst' <- bitcast dst (ptr i8) + src' <- bitcast src (ptr i8) + call (ConstantOperand P.memcpy) [(dst', []), (src', []), (len, []), (bit 0, [])] + +free :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +free memPtr = do + memPtr' <- bitcast memPtr (ptr i8) + call (ConstantOperand P.free) [(memPtr', [])] + +makePair :: (MonadIRBuilder m, MonadModuleBuilder m) => (IR.Type, Operand) -> (IR.Type, Operand) -> m Operand +makePair (leftTy, leftVal) (rightTy, rightVal) = do + let pairTy = StructureType False [leftTy, rightTy] + let pair = ConstantOperand $ C.Undef pairTy + pair' <- insertValue pair leftVal [0] + insertValue pair' rightVal [1] + +-- Arrays + +arrayStruct :: IR.Type -> IR.Type +arrayStruct ty = StructureType False [i32, ptr ty] + +floatArrayStruct :: IR.Type +floatArrayStruct = arrayStruct IR.double + +getArrayLen :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +getArrayLen arrPtr = do + lenField <- gep arrPtr [int32 0, int32 0] + load lenField 0 + +getArrayDataPtr :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +getArrayDataPtr arrPtr = do + dataField <- gep arrPtr [int32 0, int32 1] + load dataField 0 + +makeArray :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> Operand -> Operand -> m Operand +makeArray ty len dataPtr = do + arrPtr <- allocateStruct (arrayStruct ty) + + lenField <- gep arrPtr [int32 0, int32 0] + store lenField 0 len + + dataField <- gep arrPtr [int32 0, int32 1] + store dataField 0 dataPtr + + return arrPtr + +-- Granule types mkFunType :: [Gr.Type] -> Gr.Type -> Gr.Type mkFunType args ret = foldr (FunTy Nothing Nothing) ret args -structTy :: IR.Type -structTy = StructureType False [i32, ptr IR.double] - tyInt :: Gr.Type tyInt = TyCon (Id "Int" "Int") From 2ee307eff6ce50a97cc11c9def4a4c76c04aec67 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 01:39:04 +0000 Subject: [PATCH 18/25] cleanup floatarray shared implementation --- .../Granule/Codegen/Builtins/Builtins.hs | 18 +-- .../Granule/Codegen/Builtins/FloatArray.hs | 106 ++++++++++++++++++ .../Codegen/Builtins/ImmutableArray.hs | 68 ----------- .../Granule/Codegen/Builtins/MutableArray.hs | 75 ------------- .../Granule/Codegen/Builtins/Shared.hs | 53 +++++---- 5 files changed, 142 insertions(+), 178 deletions(-) create mode 100644 src/Language/Granule/Codegen/Builtins/FloatArray.hs delete mode 100644 src/Language/Granule/Codegen/Builtins/ImmutableArray.hs delete mode 100644 src/Language/Granule/Codegen/Builtins/MutableArray.hs diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index 00f5530..f4aba6e 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -1,20 +1,24 @@ module Language.Granule.Codegen.Builtins.Builtins where -import Language.Granule.Codegen.Builtins.Shared -import Language.Granule.Codegen.Builtins.ImmutableArray -import Language.Granule.Codegen.Builtins.MutableArray import Language.Granule.Codegen.Builtins.Extras +import Language.Granule.Codegen.Builtins.FloatArray +import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Syntax.Identifiers (Id, mkId) builtins :: [Builtin] -builtins = [ - charToIntDef, divDef, +builtins = + [ charToIntDef, + divDef, newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef, - newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef - ] + newFloatArrayDef, + readFloatArrayDef, + writeFloatArrayDef, + lengthFloatArrayDef, + deleteFloatArrayDef + ] builtinIds :: [Id] builtinIds = map (mkId . builtinId) builtins diff --git a/src/Language/Granule/Codegen/Builtins/FloatArray.hs b/src/Language/Granule/Codegen/Builtins/FloatArray.hs new file mode 100644 index 0000000..b6d8acb --- /dev/null +++ b/src/Language/Granule/Codegen/Builtins/FloatArray.hs @@ -0,0 +1,106 @@ +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module Language.Granule.Codegen.Builtins.FloatArray where + +import qualified LLVM.AST.Constant as C +import LLVM.AST.Operand (Operand (ConstantOperand)) +import LLVM.AST.Type as IR +import LLVM.IRBuilder.Constant as C +import LLVM.IRBuilder.Instruction +import LLVM.IRBuilder.Module (MonadModuleBuilder) +import LLVM.IRBuilder.Monad (MonadIRBuilder) +import Language.Granule.Codegen.Builtins.Shared + +-- Mutable FloatArray builtins +newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef :: Builtin +newFloatArrayDef = + Builtin "newFloatArray" [tyInt] tyFloatArray impl + where + impl [len] = newFloatArray len +readFloatArrayDef = + Builtin "readFloatArray" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl + where + impl [arrPtr, idx] = readFloatArray arrPtr idx +writeFloatArrayDef = + Builtin "writeFloatArray" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl + where + impl [arrPtr, idx, val] = do + dataPtr <- readStruct arrPtr 1 + writeData dataPtr idx val + return arrPtr +lengthFloatArrayDef = + Builtin "lengthFloatArray" [tyFloatArray] (tyPair (tyInt, tyFloatArray)) impl + where + impl [arrPtr] = lengthFloatArray arrPtr +deleteFloatArrayDef = + Builtin "deleteFloatArray" [tyFloatArray] tyUnit impl + where + impl [arrPtr] = do + dataPtr <- readStruct arrPtr 1 + _ <- free dataPtr + _ <- free arrPtr + -- return unit (need to check) + return $ ConstantOperand (C.Struct Nothing False []) + +-- Immutable FloatArray builtins +newFloatArrayIDef, readFloatArrayIDef, writeFloatArrayIDef, lengthFloatArrayIDef :: Builtin +newFloatArrayIDef = + Builtin "newFloatArrayI" [tyInt] tyFloatArray impl + where + impl [len] = newFloatArray len +readFloatArrayIDef = + Builtin "readFloatArrayI" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl + where + impl [arrPtr, idx] = readFloatArray arrPtr idx +writeFloatArrayIDef = + Builtin "writeFloatArrayI" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl + where + impl [arrPtr, idx, val] = do + len <- readStruct arrPtr 0 + dataPtr <- readStruct arrPtr 1 + + -- len * double precision + size <- mul len (int32 8) + -- malloc wants i64 + size' <- sext size i64 + newDataPtr <- allocate size' IR.double + + -- copy existing data to new array + _ <- copy newDataPtr dataPtr size + + -- write value at index + writeData newDataPtr idx val + + -- return a new array struct + makeArrayStruct IR.double len newDataPtr +lengthFloatArrayIDef = + Builtin "lengthFloatArrayI" [tyFloatArray] (tyPair (tyInt, tyFloatArray)) impl + where + impl [arrPtr] = lengthFloatArray arrPtr + +-- arrayStruct specialisation +floatArrayStruct :: IR.Type +floatArrayStruct = arrayStruct IR.double + +-- creates a new float array of length +newFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +newFloatArray len = do + -- len * double precision + size <- mul len (int32 8) + -- malloc wants i64 + size' <- sext size i64 + dataPtr <- allocate size' IR.double + makeArrayStruct IR.double len dataPtr + +-- read the value at index of float array +readFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> m Operand +readFloatArray arrPtr idx = do + dataPtr <- readStruct arrPtr 1 + value <- readData dataPtr idx + makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) + +-- read the length of float array +lengthFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand +lengthFloatArray arrPtr = do + len <- readStruct arrPtr 0 + makePair (i32, len) (ptr floatArrayStruct, arrPtr) diff --git a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs b/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs deleted file mode 100644 index d85cbeb..0000000 --- a/src/Language/Granule/Codegen/Builtins/ImmutableArray.hs +++ /dev/null @@ -1,68 +0,0 @@ -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module Language.Granule.Codegen.Builtins.ImmutableArray where - -import LLVM.AST.Type as IR -import LLVM.IRBuilder.Constant as C -import LLVM.IRBuilder.Instruction -import Language.Granule.Codegen.Builtins.Shared - --- newFloatArrayI :: Int -> FloatArray id -newFloatArrayIDef :: Builtin -newFloatArrayIDef = - Builtin "newFloatArrayI" args ret impl - where - args = [tyInt] - ret = tyFloatArray - impl [len] = do - -- length * double precision 8 bytes - dataSize <- mul len (int32 8) - dataSize64 <- sext dataSize i64 - dataPtr <- allocate dataSize64 IR.double - makeArray IR.double len dataPtr - --- readFloatArrayI :: (FloatArray id) -> Int -> (Float, FloatArray id) -readFloatArrayIDef :: Builtin -readFloatArrayIDef = - Builtin "readFloatArrayI" args ret impl - where - args = [tyFloatArray, tyInt] - ret = tyPair (tyFloat, tyFloatArray) - impl [arrPtr, idx] = do - dataPtr <- getArrayDataPtr arrPtr - valuePtr <- gep dataPtr [idx] - value <- load valuePtr 0 - makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) - --- writeFloatArrayI :: (FloatArray id) -> Int -> Float -> FloatArray id -writeFloatArrayIDef :: Builtin -writeFloatArrayIDef = - Builtin "writeFloatArrayI" args ret impl - where - args = [tyFloatArray, tyInt, tyFloat] - ret = tyFloatArray - impl [arrPtr, idx, val] = do - len <- getArrayLen arrPtr - dataPtr <- getArrayDataPtr arrPtr - - dataSize <- mul len (int32 8) - dataSize64 <- sext dataSize i64 - newDataPtr <- allocateFloatArray len - - _ <- copy newDataPtr dataPtr dataSize - - valuePtr <- gep newDataPtr [idx] - store valuePtr 0 val - - makeArray IR.double len newDataPtr - --- lengthFloatArrayI :: (FloatArray id) -> (Int -> FloatArray id) -lengthFloatArrayIDef :: Builtin -lengthFloatArrayIDef = - Builtin "lengthFloatArrayI" args ret impl - where - args = [tyFloatArray] - ret = tyPair (tyInt, tyFloatArray) - impl [arrPtr] = do - len <- getArrayLen arrPtr - makePair (i32, len) (ptr floatArrayStruct, arrPtr) diff --git a/src/Language/Granule/Codegen/Builtins/MutableArray.hs b/src/Language/Granule/Codegen/Builtins/MutableArray.hs deleted file mode 100644 index 1c433bf..0000000 --- a/src/Language/Granule/Codegen/Builtins/MutableArray.hs +++ /dev/null @@ -1,75 +0,0 @@ -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module Language.Granule.Codegen.Builtins.MutableArray where - -import LLVM.AST -import LLVM.AST.Type as IR -import qualified LLVM.AST.Constant as C -import LLVM.IRBuilder.Constant as C -import LLVM.IRBuilder.Instruction -import Language.Granule.Codegen.Builtins.Shared - --- newFloatArray :: Int -> FloatArray id -newFloatArrayDef :: Builtin -newFloatArrayDef = - Builtin "newFloatArray" args ret impl - where - args = [tyInt] - ret = tyFloatArray - impl [len] = do - -- length * double precision 8 bytes - dataSize <- mul len (int32 8) - dataSize64 <- sext dataSize i64 - dataPtr <- allocate dataSize64 IR.double - makeArray IR.double len dataPtr - --- readFloatArray :: (FloatArray id) -> Int -> (Float, FloatArray id) -readFloatArrayDef :: Builtin -readFloatArrayDef = - Builtin "readFloatArray" args ret impl - where - args = [tyFloatArray, tyInt] - ret = tyPair (tyFloat, tyFloatArray) - impl [arrPtr, idx] = do - dataPtr <- getArrayDataPtr arrPtr - valuePtr <- gep dataPtr [idx] - value <- load valuePtr 0 - makePair (IR.double, value) (ptr floatArrayStruct, arrPtr) - --- writeFloatArray :: (FloatArray id) -> Int -> Float -> FloatArray id -writeFloatArrayDef :: Builtin -writeFloatArrayDef = - Builtin "writeFloatArray" args ret impl - where - args = [tyFloatArray, tyInt, tyFloat] - ret = tyFloatArray - impl [arrPtr, idx, val] = do - dataPtr <- getArrayDataPtr arrPtr - valuePtr <- gep dataPtr [idx] - store valuePtr 0 val - return arrPtr - --- lengthFloatArray :: (FloatArray id) -> (Int -> FloatArray id) -lengthFloatArrayDef :: Builtin -lengthFloatArrayDef = - Builtin "lengthFloatArray" args ret impl - where - args = [tyFloatArray] - ret = tyPair (tyInt, tyFloatArray) - impl [arrPtr] = do - len <- getArrayLen arrPtr - makePair (i32, len) (ptr floatArrayStruct, arrPtr) - -deleteFloatArrayDef :: Builtin -deleteFloatArrayDef = - Builtin "deleteFloatArray" args ret impl - where - args = [tyFloatArray] - ret = tyUnit - impl [arrPtr] = do - dataPtr <- getArrayDataPtr arrPtr - _ <- free dataPtr - _ <- free arrPtr - - -- return unit (need to check) - return $ ConstantOperand (C.Struct Nothing False []) diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs index 2836c58..29c52c5 100644 --- a/src/Language/Granule/Codegen/Builtins/Shared.hs +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -6,7 +6,7 @@ module Language.Granule.Codegen.Builtins.Shared where import LLVM.AST import LLVM.AST.Type as IR import LLVM.IRBuilder.Constant (bit, int32) -import LLVM.IRBuilder.Instruction (bitcast, call, gep, load, store, insertValue, mul, sext) +import LLVM.IRBuilder.Instruction (bitcast, call, gep, load, store, insertValue) import LLVM.IRBuilder.Module import LLVM.IRBuilder.Monad import Language.Granule.Codegen.Emit.LLVMHelpers (sizeOf) @@ -31,12 +31,6 @@ allocate len ty = do allocateStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> m Operand allocateStruct ty = allocate (ConstantOperand $ sizeOf ty) ty -allocateFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand -allocateFloatArray len = do - dataSize <- mul len (int32 8) - dataSize64 <- sext dataSize i64 - allocate dataSize64 IR.double - copy :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> Operand -> m Operand copy dst src len = do dst' <- bitcast dst (ptr i8) @@ -55,35 +49,38 @@ makePair (leftTy, leftVal) (rightTy, rightVal) = do pair' <- insertValue pair leftVal [0] insertValue pair' rightVal [1] +writeStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Int -> Operand -> m () +writeStruct struct index value = do + field <- gep struct [int32 0, int32 $ fromIntegral index] + store field 0 value + +readStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Int -> m Operand +readStruct struct index = do + field <- gep struct [int32 0, int32 $ fromIntegral index] + load field 0 + -- Arrays arrayStruct :: IR.Type -> IR.Type arrayStruct ty = StructureType False [i32, ptr ty] -floatArrayStruct :: IR.Type -floatArrayStruct = arrayStruct IR.double - -getArrayLen :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand -getArrayLen arrPtr = do - lenField <- gep arrPtr [int32 0, int32 0] - load lenField 0 - -getArrayDataPtr :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Operand -getArrayDataPtr arrPtr = do - dataField <- gep arrPtr [int32 0, int32 1] - load dataField 0 - -makeArray :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> Operand -> Operand -> m Operand -makeArray ty len dataPtr = do +-- creates a struct for len and array data +makeArrayStruct :: (MonadIRBuilder m, MonadModuleBuilder m) => IR.Type -> Operand -> Operand -> m Operand +makeArrayStruct ty len dataPtr = do arrPtr <- allocateStruct (arrayStruct ty) + writeStruct arrPtr 0 len + writeStruct arrPtr 1 dataPtr + return arrPtr - lenField <- gep arrPtr [int32 0, int32 0] - store lenField 0 len - - dataField <- gep arrPtr [int32 0, int32 1] - store dataField 0 dataPtr +writeData :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> Operand -> m () +writeData dataPtr index value = do + valuePtr <- gep dataPtr [index] + store valuePtr 0 value - return arrPtr +readData :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> Operand -> m Operand +readData dataPtr index = do + valuePtr <- gep dataPtr [index] + load valuePtr 0 -- Granule types From 76211b08f93b519aee8b85ffc4a0cbb87ac9fdfe Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 01:41:08 +0000 Subject: [PATCH 19/25] updated cabal file --- granule-compiler.cabal | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 15d5fa7..039c0d0 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -31,8 +31,7 @@ library other-modules: Language.Granule.Codegen.Builtins.Builtins Language.Granule.Codegen.Builtins.Extras - Language.Granule.Codegen.Builtins.ImmutableArray - Language.Granule.Codegen.Builtins.MutableArray + Language.Granule.Codegen.Builtins.FloatArray Language.Granule.Codegen.Builtins.Shared Language.Granule.Codegen.Emit.EmitableDef Language.Granule.Codegen.Emit.EmitBuiltins From 9cffc825871738cf36327ebe51a508eb491b8e79 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 18:00:40 +0000 Subject: [PATCH 20/25] not relevant --- src/Language/Granule/Codegen/PrintAST.hs | 117 ----------------------- 1 file changed, 117 deletions(-) delete mode 100644 src/Language/Granule/Codegen/PrintAST.hs diff --git a/src/Language/Granule/Codegen/PrintAST.hs b/src/Language/Granule/Codegen/PrintAST.hs deleted file mode 100644 index 83180ca..0000000 --- a/src/Language/Granule/Codegen/PrintAST.hs +++ /dev/null @@ -1,117 +0,0 @@ -module Language.Granule.Codegen.PrintAST where - -import Language.Granule.Syntax.Def -import Language.Granule.Syntax.Expr -import Language.Granule.Syntax.Identifiers (Id (Id)) -import Language.Granule.Syntax.Pattern -import Language.Granule.Syntax.Type -import Data.Bifunctor (bimap) - --- converts an AST to JSON for debugging & visualisation. WIP - -data JSON = Str String | Obj [(String, JSON)] | Arr [JSON] | Pair (JSON, JSON) - -printAST :: AST ev Type -> String -printAST ast = toString (jAST ast) - -toString :: JSON -> String -toString (Str s) = "\"" ++ s ++ "\"" -toString (Obj pairs) = "{" ++ commas (map field pairs) ++ "}" -toString (Arr elems) = "[" ++ commas (map toString elems) ++ "]" -toString (Pair (left, right)) = toString (Arr [left, right]) - -field :: (String, JSON) -> String -field (key, value) = "\"" ++ key ++ "\":" ++ toString value - -commas :: [String] -> String -commas [] = "" -commas [x] = x -commas (x : xs) = x ++ "," ++ commas xs - -node :: String -> [(String, JSON)] -> JSON -node name fields = Obj [(name, Obj fields)] - -jAST :: AST ev Type -> JSON -jAST (AST _ defs _ _ _) = Arr (map jDef defs) - -jDef :: Def ev Type -> JSON -jDef (Def _ _ _ _ equations _) = jEqList equations - -jEqList :: EquationList ev Type -> JSON -jEqList (EquationList _ _ _ equations) = Arr (map jEq equations) - -jEq :: Equation ev Type -> JSON -jEq (Equation s i a b ps expr) = jExpr expr - -jExpr :: Expr ev Type -> JSON -jExpr (App _ ty _ fn arg) = node "App" [("type", jTy ty), ("fn", jExpr fn), ("arg", jExpr arg)] -jExpr (Binop _ ty _ op e1 e2) = node "Binop" [("type", jTy ty), ("op", Str (show op)), ("e1", jExpr e1), ("e2", jExpr e2)] -jExpr (LetDiamond {}) = Str "%LETDIAMOND%" -jExpr (Val _ ty _ val) = node "Val" [("type", jTy ty), ("val", jVal val)] -jExpr (Case s ty b e ps) = node "Case" [("type", jTy ty), ("expr", jExpr e), ("cases", Arr (map (Pair . bimap jPat jExpr) ps))] -jExpr (Hole {}) = Str "%HOLE%" -jExpr (AppTy {}) = Str "%APP_TY%" -jExpr (TryCatch {}) = Str "%TRY_CATCH%" -jExpr (Unpack _ ty _ _ id e1 e2) = node "Unpack" [("type", jTy ty), ("id", jId id), ("e1", jExpr e1), ("e2", jExpr e2)] - - -jVal :: Value ev Type -> JSON -jVal (Var ty id) = node "Var" [("type", jTy ty), ("id", jId id)] -jVal (Abs ty pat mty expr) = node "Abs" [("type", jTy ty), ("pattern", jPat pat), ("expr", jExpr expr)] -jVal (Promote ty expr) = node "Promote" [("type", jTy ty), ("expr", jExpr expr)] -jVal (Pure ty expr) = node "Pure" [("type", jTy ty), ("expr", jExpr expr)] -jVal (Constr ty id vals) = node "Constr" [("type", jTy ty), ("id", jId id), ("vals", Arr (map jVal vals))] -jVal (NumInt n) = Str (show n) -jVal (NumFloat n) = Str (show n) -jVal (CharLiteral c) = Str (show c) -jVal (StringLiteral s) = Str (show s) -jVal (Ext {}) = Str "%EXT%" -jVal (Nec {}) = Str "%NEC%" -jVal (Pack {}) = Str "%PACK%" -jVal (TyAbs {}) = Str "%TY_ABS%" - -jPat :: Pattern Type -> JSON -jPat (PVar _ ty _ id) = node "PVar" [("type", jTy ty), ("id", jId id)] -jPat (PWild s a b) = Str "_" -jPat (PBox _ ty _ p) = node "PBox" [("ty", jTy ty), ("pat", jPat p)] -jPat (PInt s a b v) = Str (show v) -jPat (PFloat s a b v) = Str (show v) -jPat (PConstr _ ty _ id _ pats) = node "PConstr" [("type", jTy ty), ("id", jId id), ("pats", Arr (map jPat pats))] - -jId :: Id -> JSON -jId id = Str (sId id) - -jTy :: Type -> JSON -jTy ty = Str (sTy ty) - -sId :: Id -> String -sId (Id _ id) = id - -paren :: String -> String -paren str = "(" ++ str ++ ")" - -named :: String -> String -> String -named name str = name ++ " " ++ paren str - -sTy :: Type -> String -sTy (Type {}) = "type" -sTy (FunTy id _ arg ret) = sTy arg ++ " -> " ++ sTy ret -sTy (TyCon id) = sId id -sTy (Box _ ty) = "[" ++ sTy ty ++ "]" -sTy (Diamond {}) = "%DIAMOND%" -sTy (Star {}) = "%STAR%" -sTy (Borrow (TyCon (Id "Star" "Star")) ty) = "*" ++ paren (sTy ty) -sTy (Borrow p ty) = "& " ++ paren (sTy p) ++ " " ++ paren (sTy ty) -sTy (TyVar id) = "TyVar (" ++ sId id ++ ")" -sTy (TyApp t1 t2) = "(" ++ sTy t1 ++ ") (" ++ sTy t2 ++ ")" -sTy (TyInt {}) = "%TY_INT%" -sTy (TyRational {}) = "%TY_RATIONAL%" -sTy (TyFraction {}) = "%TY_FRACTION%" -sTy (TyGrade {}) = "%TY_GRADE%" -sTy (TyInfix {}) = "%TY_INFIX%" -sTy (TySet {}) = "%TY_SET%" -sTy (TyCase {}) = "%TY_CASE%" -sTy (TySig {}) = "%TY_SIG%" -sTy (TyExists id kind ty) = "exists {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty -sTy (TyForall id kind ty) = "forall {" ++ sId id ++ " : " ++ sTy kind ++ "} . " ++ sTy ty -sTy (TyName {}) = "%TY_NAME%" From bac263c855e7dbe320d356817ca72c01c3d0d3e7 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 18:10:11 +0000 Subject: [PATCH 21/25] unseparation of concerns --- granule-compiler.cabal | 2 - src/Language/Granule/Codegen/RetypeAST.hs | 109 --------------------- src/Language/Granule/Codegen/RewriteAST.hs | 102 ++++++++++++++++++- 3 files changed, 98 insertions(+), 115 deletions(-) delete mode 100644 src/Language/Granule/Codegen/RetypeAST.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 039c0d0..15b0453 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -47,8 +47,6 @@ library Language.Granule.Codegen.Emit.Names Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types - Language.Granule.Codegen.PrintAST - Language.Granule.Codegen.RetypeAST Language.Granule.Codegen.RewriteAST Language.Granule.Codegen.StripAST Paths_granule_compiler diff --git a/src/Language/Granule/Codegen/RetypeAST.hs b/src/Language/Granule/Codegen/RetypeAST.hs deleted file mode 100644 index 7933b25..0000000 --- a/src/Language/Granule/Codegen/RetypeAST.hs +++ /dev/null @@ -1,109 +0,0 @@ -module Language.Granule.Codegen.RetypeAST where - -import Data.Bifunctor (bimap) -import Data.List (mapAccumL) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) -import Language.Granule.Syntax.Def -import Language.Granule.Syntax.Expr -import Language.Granule.Syntax.Identifiers (Id) -import Language.Granule.Syntax.Pattern (Pattern (..)) -import Language.Granule.Syntax.Type - --- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these --- are not already handled by the compiler. Here we find the correct types and substitute --- the TyVars. WIP. - --- val var -> Type, type var -> Type -type Env = (Map.Map Id Type, Map.Map Id Type) - -emptyEnv :: Env -emptyEnv = (Map.empty, Map.empty) - -insertEnv :: Env -> Either Id Id -> Type -> Env -insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) -insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) - -lookupEnv :: Env -> Either Id Id -> Maybe Type -lookupEnv (vals, tys) (Left id) = Map.lookup id vals -lookupEnv (vals, tys) (Right id) = Map.lookup id tys - -retypeAST :: AST ev Type -> AST ev Type -retypeAST ast = ast {definitions = map retypeDef (definitions ast)} - where - retypeDef def = def {defEquations = retypeEquationList (defEquations def)} - retypeEquationList eqs = eqs {equations = map retypeEquation (equations eqs)} - retypeEquation eq = eq {equationBody = snd (retypeExpr emptyEnv (equationBody eq))} - -retypeExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) -retypeExpr env (App s ty b e1 e2) = - let (env', e2') = retypeExpr env e2 - (env'', e1') = retypeExpr env' e1 - ty' = subsTy env ty - in (env'', App s ty' b e1' e2') -retypeExpr env (Val s ty b v) = - let (env', v') = retypeVal env v - ty' = subsTy env' ty - in (env', Val s ty' b v') -retypeExpr env exp = error "TODO expr" - -retypeVal :: Env -> Value ev Type -> (Env, Value ev Type) -retypeVal env (Var (TyVar id) var) = - -- see if we already have it - case lookupEnv env (Right id) of - Just ty -> (env, Var ty var) - Nothing -> - -- see if the value variable has it - case lookupEnv env (Left var) of - -- and update - Just ty -> (insertEnv env (Right id) ty, Var ty var) - -- we wont always win - Nothing -> (env, Var (TyVar id) var) -retypeVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) -retypeVal env (Abs ty p mt e) = - let (env', p') = retypePat env p - (env'', e') = retypeExpr env' e - ty' = subsTy env'' ty - in (env'', Abs ty' p' mt e') -retypeVal env (Constr ty id vals) = - let (env', vals') = mapAccumL retypeVal env vals - ty' = subsTy env' ty - in (env', Constr ty' id vals') -retypeVal env (NumInt v) = (env, NumInt v) -retypeVal env (NumFloat v) = (env, NumFloat v) -retypeVal env (Promote t v) = (env, Promote t v) -retypeVal env val = error "TODO val" - -retypePat :: Env -> Pattern Type -> (Env, Pattern Type) -retypePat env (PVar s (TyVar id) b var) = - case lookupEnv env (Right id) of - Just ty -> (env, PVar s ty b var) - Nothing -> - case lookupEnv env (Left var) of - Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) - Nothing -> (env, PVar s (TyVar id) b var) -retypePat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) -retypePat env (PConstr s ty b id ids ps) = - let (env', ps') = mapAccumL retypePat env ps - ty' = subsTy env' ty - in (env', PConstr s ty' b id ids ps') -retypePat env p = error "TODO pat" - -subsTy :: Env -> Type -> Type -subsTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) -subsTy env (Type i) = Type i -subsTy env (FunTy id mc arg ret) = FunTy id mc (subsTy env arg) (subsTy env ret) -subsTy env (TyCon id) = TyCon id -subsTy env (Box c t) = subsTy env t -subsTy env (Diamond e t) = Diamond (subsTy env e) (subsTy env t) -subsTy env (Star g t) = subsTy env t -subsTy env (Borrow p t) = subsTy env t -subsTy env (TyApp t1 t2) = TyApp (subsTy env t1) (subsTy env t2) -subsTy env (TyGrade mt i) = TyGrade mt i -subsTy env (TyInfix op t1 t2) = TyInfix op (subsTy env t1) (subsTy env t2) -subsTy env (TySet p ts) = TySet p (map (subsTy env) ts) -subsTy env (TyCase t tps) = TyCase (subsTy env t) (map (bimap (subsTy env) (subsTy env)) tps) -subsTy env (TySig t k) = TySig (subsTy env t) (subsTy env k) -subsTy env (TyExists id k t) = subsTy env t -subsTy env (TyForall id k t) = subsTy env t -subsTy env t = t diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs index f3f52fe..32f710b 100644 --- a/src/Language/Granule/Codegen/RewriteAST.hs +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -1,10 +1,14 @@ module Language.Granule.Codegen.RewriteAST where +import Data.Bifunctor (bimap) +import Data.List (mapAccumL) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers (Id) import Language.Granule.Syntax.Pattern import Language.Granule.Syntax.Type -import Language.Granule.Codegen.RetypeAST -- Rewrite Unpack ASTs into App Abs ASTs which our -- compiler already knows how to handle. WIP. @@ -23,11 +27,10 @@ rewriteExpr (Unpack s retTy b tyVar var e1 e2) = e2' = e2 absTy = FunTy Nothing Nothing e1Ty retTy in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') - where - fixTypes expr = snd $ retypeExpr emptyEnv expr + where + fixTypes expr = snd $ substExpr emptyEnv expr rewriteExpr exp = exp - exprTy :: Expr ev Type -> Type exprTy (App _ ty _ _ _) = ty exprTy (Val _ ty _ _) = ty @@ -38,3 +41,94 @@ exprTy (Hole _ ty _ _ _) = ty exprTy (AppTy _ ty _ _ _) = ty exprTy (TryCatch _ ty _ _ _ _ _ _) = ty exprTy (Unpack _ ty _ _ _ _ _) = ty + +-- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these +-- are not already handled by the compiler. Here we find the correct types and substitute +-- the TyVars. WIP. + +-- val var -> Type, type var -> Type +type Env = (Map.Map Id Type, Map.Map Id Type) + +emptyEnv :: Env +emptyEnv = (Map.empty, Map.empty) + +insertEnv :: Env -> Either Id Id -> Type -> Env +insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) +insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) + +lookupEnv :: Env -> Either Id Id -> Maybe Type +lookupEnv (vals, tys) (Left id) = Map.lookup id vals +lookupEnv (vals, tys) (Right id) = Map.lookup id tys + +substExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) +substExpr env (App s ty b e1 e2) = + let (env', e2') = substExpr env e2 + (env'', e1') = substExpr env' e1 + ty' = substTy env ty + in (env'', App s ty' b e1' e2') +substExpr env (Val s ty b v) = + let (env', v') = substVal env v + ty' = substTy env' ty + in (env', Val s ty' b v') +substExpr env exp = error "TODO expr" + +substVal :: Env -> Value ev Type -> (Env, Value ev Type) +substVal env (Var (TyVar id) var) = + -- see if we already have it + case lookupEnv env (Right id) of + Just ty -> (env, Var ty var) + Nothing -> + -- see if the value variable has it + case lookupEnv env (Left var) of + -- and update + Just ty -> (insertEnv env (Right id) ty, Var ty var) + -- we wont always win + Nothing -> (env, Var (TyVar id) var) +substVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) +substVal env (Abs ty p mt e) = + let (env', p') = substPat env p + (env'', e') = substExpr env' e + ty' = substTy env'' ty + in (env'', Abs ty' p' mt e') +substVal env (Constr ty id vals) = + let (env', vals') = mapAccumL substVal env vals + ty' = substTy env' ty + in (env', Constr ty' id vals') +substVal env (NumInt v) = (env, NumInt v) +substVal env (NumFloat v) = (env, NumFloat v) +substVal env (Promote t v) = (env, Promote t v) +substVal env val = error "TODO val" + +substPat :: Env -> Pattern Type -> (Env, Pattern Type) +substPat env (PVar s (TyVar id) b var) = + case lookupEnv env (Right id) of + Just ty -> (env, PVar s ty b var) + Nothing -> + case lookupEnv env (Left var) of + Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) + Nothing -> (env, PVar s (TyVar id) b var) +substPat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) +substPat env (PConstr s ty b id ids ps) = + let (env', ps') = mapAccumL substPat env ps + ty' = substTy env' ty + in (env', PConstr s ty' b id ids ps') +substPat env p = error "TODO pat" + +substTy :: Env -> Type -> Type +substTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) +substTy env (Type i) = Type i +substTy env (FunTy id mc arg ret) = FunTy id mc (substTy env arg) (substTy env ret) +substTy env (TyCon id) = TyCon id +substTy env (Box c t) = substTy env t +substTy env (Diamond e t) = Diamond (substTy env e) (substTy env t) +substTy env (Star g t) = substTy env t +substTy env (Borrow p t) = substTy env t +substTy env (TyApp t1 t2) = TyApp (substTy env t1) (substTy env t2) +substTy env (TyGrade mt i) = TyGrade mt i +substTy env (TyInfix op t1 t2) = TyInfix op (substTy env t1) (substTy env t2) +substTy env (TySet p ts) = TySet p (map (substTy env) ts) +substTy env (TyCase t tps) = TyCase (substTy env t) (map (bimap (substTy env) (substTy env)) tps) +substTy env (TySig t k) = TySig (substTy env t) (substTy env k) +substTy env (TyExists id k t) = substTy env t +substTy env (TyForall id k t) = substTy env t +substTy env t = t From 38517a439330e876e9382ca8f4249047c2ed07e9 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Mon, 24 Mar 2025 18:10:16 +0000 Subject: [PATCH 22/25] clarification --- src/Language/Granule/Codegen/StripAST.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Granule/Codegen/StripAST.hs b/src/Language/Granule/Codegen/StripAST.hs index b4c991d..bd77fdc 100644 --- a/src/Language/Granule/Codegen/StripAST.hs +++ b/src/Language/Granule/Codegen/StripAST.hs @@ -6,10 +6,10 @@ import Language.Granule.Syntax.Expr import Language.Granule.Syntax.Pattern import Language.Granule.Syntax.Type --- Strips types which are not currently needed (or handled) by --- the compiler, to make life easier and debugging simpler. We are --- stripping Box, Star, Borrow and type quantifiers, but we may --- wish to reinstate these to help with future optimisation. WIP. +-- Temporarily strip types which are not currently needed (or handled) +-- by the compiler, to make life easier and debugging simpler. We are +-- stripping Box, Star, Borrow and type quantifiers, but we will wish +-- to reinstate these later to help with optimisation. stripAST :: AST ev Type -> AST ev Type stripAST (AST decls defs imports hidden name) = From a95e32bd1846348c2877e0d72e94af8dcc3657e7 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Tue, 1 Apr 2025 14:36:23 +0100 Subject: [PATCH 23/25] fix test and add bounds check --- .../Granule/Codegen/Builtins/FloatArray.hs | 37 ++++++++++++++++--- tests/golden/positive/unpack.gr | 2 +- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins/FloatArray.hs b/src/Language/Granule/Codegen/Builtins/FloatArray.hs index b6d8acb..9a803e4 100644 --- a/src/Language/Granule/Codegen/Builtins/FloatArray.hs +++ b/src/Language/Granule/Codegen/Builtins/FloatArray.hs @@ -3,13 +3,17 @@ module Language.Granule.Codegen.Builtins.FloatArray where import qualified LLVM.AST.Constant as C +import LLVM.AST.IntegerPredicate import LLVM.AST.Operand (Operand (ConstantOperand)) import LLVM.AST.Type as IR +import LLVM.IRBuilder (emitBlockStart) import LLVM.IRBuilder.Constant as C import LLVM.IRBuilder.Instruction import LLVM.IRBuilder.Module (MonadModuleBuilder) -import LLVM.IRBuilder.Monad (MonadIRBuilder) +import LLVM.IRBuilder.Monad (MonadIRBuilder, freshName) import Language.Granule.Codegen.Builtins.Shared +import Language.Granule.Codegen.Emit.Primitives (trap) +import Prelude hiding (or) -- Mutable FloatArray builtins newFloatArrayDef, readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, deleteFloatArrayDef :: Builtin @@ -20,11 +24,11 @@ newFloatArrayDef = readFloatArrayDef = Builtin "readFloatArray" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl where - impl [arrPtr, idx] = readFloatArray arrPtr idx + impl [arrPtr, idx] = withBoundsCheck arrPtr idx $ readFloatArray arrPtr idx writeFloatArrayDef = Builtin "writeFloatArray" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl where - impl [arrPtr, idx, val] = do + impl [arrPtr, idx, val] = withBoundsCheck arrPtr idx $ do dataPtr <- readStruct arrPtr 1 writeData dataPtr idx val return arrPtr @@ -51,11 +55,11 @@ newFloatArrayIDef = readFloatArrayIDef = Builtin "readFloatArrayI" [tyFloatArray, tyInt] (tyPair (tyFloat, tyFloatArray)) impl where - impl [arrPtr, idx] = readFloatArray arrPtr idx + impl [arrPtr, idx] = withBoundsCheck arrPtr idx $ readFloatArray arrPtr idx writeFloatArrayIDef = Builtin "writeFloatArrayI" [tyFloatArray, tyInt, tyFloat] tyFloatArray impl where - impl [arrPtr, idx, val] = do + impl [arrPtr, idx, val] = withBoundsCheck arrPtr idx $ do len <- readStruct arrPtr 0 dataPtr <- readStruct arrPtr 1 @@ -104,3 +108,26 @@ lengthFloatArray :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> m Ope lengthFloatArray arrPtr = do len <- readStruct arrPtr 0 makePair (i32, len) (ptr floatArrayStruct, arrPtr) + +withBoundsCheck :: + (MonadIRBuilder m, MonadModuleBuilder m) => + Operand -> + Operand -> + m Operand -> + m Operand +withBoundsCheck arrPtr idx continuation = do + len <- readStruct arrPtr 0 + ltZero <- icmp SLT idx (int32 0) + gteLen <- icmp SGE idx len + outOfBounds <- or ltZero gteLen + + abort <- freshName "out_of_bounds" + continue <- freshName "in_bounds" + + condBr outOfBounds abort continue + + emitBlockStart abort + trap + + emitBlockStart continue + continuation diff --git a/tests/golden/positive/unpack.gr b/tests/golden/positive/unpack.gr index 59ffe55..db537f9 100644 --- a/tests/golden/positive/unpack.gr +++ b/tests/golden/positive/unpack.gr @@ -1,6 +1,6 @@ main : (Float, Float) main = - unpack = newFloatArray 1 in + unpack = newFloatArray 2 in let arr' = writeFloatArray arr 0 42.0 in let arr'' = writeFloatArray arr' 1 100.0 in let (x, arr''') = readFloatArray arr'' 0 in From 2cb3e80c2b5d0885da0444eec07f4b47cfbacd60 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Tue, 1 Apr 2025 14:41:30 +0100 Subject: [PATCH 24/25] remove type stripping --- granule-compiler.cabal | 1 - src/Language/Granule/Codegen/Compile.hs | 4 +- .../Granule/Codegen/Emit/LowerType.hs | 2 +- src/Language/Granule/Codegen/StripAST.hs | 110 ------------------ 4 files changed, 2 insertions(+), 115 deletions(-) delete mode 100644 src/Language/Granule/Codegen/StripAST.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 15b0453..450072e 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -48,7 +48,6 @@ library Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types Language.Granule.Codegen.RewriteAST - Language.Granule.Codegen.StripAST Paths_granule_compiler hs-source-dirs: src diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index 101a52d..892c241 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -9,14 +9,12 @@ import Language.Granule.Codegen.ConvertClosures import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals import Language.Granule.Codegen.RewriteAST -import Language.Granule.Codegen.StripAST import qualified LLVM.AST as IR compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let stripped = stripAST typedAST - rewritten = rewriteAST stripped + let rewritten = rewriteAST typedAST normalised = normaliseDefinitions rewritten markedGlobals = markGlobals normalised (Ok topsorted) = topologicallySortDefinitions markedGlobals diff --git a/src/Language/Granule/Codegen/Emit/LowerType.hs b/src/Language/Granule/Codegen/Emit/LowerType.hs index 325097a..b8ec373 100644 --- a/src/Language/Granule/Codegen/Emit/LowerType.hs +++ b/src/Language/Granule/Codegen/Emit/LowerType.hs @@ -55,7 +55,7 @@ llvmType (TyCon (MkId "Handle")) = i8 llvmType (TyCon (MkId "Bool")) = i1 llvmType (Box coeffect ty) = llvmType ty llvmType (TyExists _ _ ty) = llvmType ty -llvmType (Borrow (TyCon (MkId "Star")) ty) = llvmType ty +llvmType (Borrow _ ty) = llvmType ty llvmType ty = error $ "Cannot lower the type " ++ show ty llvmTopLevelType :: GrType -> IrType diff --git a/src/Language/Granule/Codegen/StripAST.hs b/src/Language/Granule/Codegen/StripAST.hs deleted file mode 100644 index bd77fdc..0000000 --- a/src/Language/Granule/Codegen/StripAST.hs +++ /dev/null @@ -1,110 +0,0 @@ -module Language.Granule.Codegen.StripAST where - -import Data.Bifunctor (Bifunctor (second), bimap) -import Language.Granule.Syntax.Def -import Language.Granule.Syntax.Expr -import Language.Granule.Syntax.Pattern -import Language.Granule.Syntax.Type - --- Temporarily strip types which are not currently needed (or handled) --- by the compiler, to make life easier and debugging simpler. We are --- stripping Box, Star, Borrow and type quantifiers, but we will wish --- to reinstate these later to help with optimisation. - -stripAST :: AST ev Type -> AST ev Type -stripAST (AST decls defs imports hidden name) = - AST decls (map stripDef defs) imports hidden name - -stripDef :: Def ev Type -> Def ev Type -stripDef (Def s i b spec el ts) = - Def s i b spec (stripEquationList el) (stripTypeScheme ts) - -stripEquationList :: EquationList ev Type -> EquationList ev Type -stripEquationList (EquationList s v b es) = - EquationList s v b (map stripEquation es) - -stripEquation :: Equation ev Type -> Equation ev Type -stripEquation (Equation s n a b ps e) = - Equation s n (stripTy a) b (map stripPat ps) (stripExpr e) - -stripExpr :: Expr ev Type -> Expr ev Type -stripExpr (App s a b e1 e2) = - App s (stripTy a) b (stripExpr e1) (stripExpr e2) -stripExpr (Binop s a b op e1 e2) = - Binop s (stripTy a) b op (stripExpr e1) (stripExpr e2) -stripExpr (LetDiamond s a b p mt e1 e2) = - LetDiamond s (stripTy a) b (stripPat p) (stripMaybeTy mt) (stripExpr e1) (stripExpr e2) -stripExpr (Val s a b v) = - Val s (stripTy a) b (stripVal v) -stripExpr (Case s a b e pes) = - Case s (stripTy a) b (stripExpr e) (map (bimap stripPat stripExpr) pes) -stripExpr (Hole s a b ids hints) = - Hole s (stripTy a) b ids hints -stripExpr (AppTy s a b e t) = - AppTy s (stripTy a) b (stripExpr e) (stripTy t) -stripExpr (TryCatch s a b e1 p mt e2 e3) = - TryCatch s (stripTy a) b (stripExpr e1) (stripPat p) (stripMaybeTy mt) (stripExpr e2) (stripExpr e3) -stripExpr (Unpack s a b tyVar var e1 e2) = - Unpack s (stripTy a) b tyVar var (stripExpr e1) (stripExpr e2) - -stripVal :: Value ev Type -> Value ev Type -stripVal (Var a id) = - Var (stripTy a) id -stripVal (Abs a p mt e) = - Abs (stripTy a) (stripPat p) (stripMaybeTy mt) (stripExpr e) -stripVal (Promote a e) = - Promote (stripTy a) (stripExpr e) -stripVal (Pure a e) = - Pure (stripTy a) (stripExpr e) -stripVal (Constr a id vs) = - Constr (stripTy a) id (map stripVal vs) -stripVal (Ext a ev) = - Ext (stripTy a) ev -stripVal (Nec a e) = - Nec (stripTy a) (stripExpr e) -stripVal (Pack s a t e id k t') = - Pack s (stripTy a) (stripTy t) (stripExpr e) id (stripTy k) (stripTy t') -stripVal (TyAbs a (Left (id, t)) e) = - TyAbs (stripTy a) (Left (id, stripTy t)) (stripExpr e) -stripVal (TyAbs a (Right ids) e) = - TyAbs (stripTy a) (Right ids) (stripExpr e) -stripVal v = v - -stripPat :: Pattern Type -> Pattern Type -stripPat (PVar s a b v) = PVar s (stripTy a) b v -stripPat (PWild s a b) = PWild s (stripTy a) b -stripPat (PBox s a b p) = stripPat p -stripPat (PInt s a b i) = PInt s (stripTy a) b i -stripPat (PFloat s a b f) = PFloat s (stripTy a) b f -stripPat (PConstr s a b id ids ps) = PConstr s (stripTy a) b id ids (map stripPat ps) - -stripTypeScheme :: TypeScheme -> TypeScheme -stripTypeScheme (Forall s quants constraints t) = - Forall - s - (map (second stripTy) quants) - (map stripTy constraints) - (stripTy t) - -stripMaybeTy :: Maybe Type -> Maybe Type -stripMaybeTy Nothing = Nothing -stripMaybeTy (Just ty) = Just (stripTy ty) - -stripTy :: Type -> Type -stripTy (Type i) = Type i -stripTy (FunTy id mc arg ret) = FunTy id (stripMaybeTy mc) (stripTy arg) (stripTy ret) -stripTy (TyCon id) = TyCon id -stripTy (Box c t) = stripTy t -stripTy (Diamond e t) = Diamond (stripTy e) (stripTy t) -stripTy (Star g t) = stripTy t -stripTy (Borrow p t) = stripTy t -stripTy (TyVar id) = TyVar id -stripTy (TyApp t1 t2) = TyApp (stripTy t1) (stripTy t2) -stripTy (TyGrade mt i) = TyGrade (stripMaybeTy mt) i -stripTy (TyInfix op t1 t2) = TyInfix op (stripTy t1) (stripTy t2) -stripTy (TySet p ts) = TySet p (map stripTy ts) -stripTy (TyCase t tps) = TyCase (stripTy t) (map (bimap stripTy stripTy) tps) -stripTy (TySig t k) = TySig (stripTy t) (stripTy k) -stripTy (TyExists id k t) = stripTy t -stripTy (TyForall id k t) = stripTy t -stripTy t = t From aabfa3aa9fefd632b9def3f8f74d9ebe47300316 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Tue, 1 Apr 2025 14:42:09 +0100 Subject: [PATCH 25/25] no longer in use --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index fa56521..48f0299 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ .stack-work/ stack.yaml.lock -.tmp/