Skip to content

Commit 4a28127

Browse files
committed
feat: support direct recursion sugar for concrete sumtypes
Enables users to use "direct" recursion on sumtypes and abstracts away pointers in function signatures such as case initers.
1 parent f35761f commit 4a28127

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

src/Deftype.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ templatizeTy (VarTy vt) = VarTy ("$" ++ vt)
465465
templatizeTy (FuncTy argTys retTy ltTy) = FuncTy (map templatizeTy argTys) (templatizeTy retTy) (templatizeTy ltTy)
466466
templatizeTy (StructTy name tys) = StructTy name (map templatizeTy tys)
467467
templatizeTy (RefTy t lt) = RefTy (templatizeTy t) (templatizeTy lt)
468+
templatizeTy (RecTy t) = t
468469
templatizeTy (PointerTy t) = PointerTy (templatizeTy t)
469470
templatizeTy t = t
470471

src/Emit.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ defSumtypeToDeclaration sumTy@(StructTy _ _) rest =
867867
appendToSrc ("struct " ++ tyToC sumTy ++ " {\n")
868868
else appendToSrc "typedef struct {\n")
869869
appendToSrc (addIndent indent ++ "union {\n")
870-
mapM_ (emitSumtypeCase indent) rest
870+
mapM_ (emitSumtypeCase indent) pointerfix
871871
appendToSrc (addIndent indent ++ "char __dummy;\n")
872872
appendToSrc (addIndent indent ++ "} u;\n")
873873
appendToSrc (addIndent indent ++ "char _tag;\n")
@@ -876,7 +876,7 @@ defSumtypeToDeclaration sumTy@(StructTy _ _) rest =
876876
(appendToSrc (" " ++ tyToC sumTy))
877877
appendToSrc ";\n"
878878
--appendToSrc ("// " ++ show typeVariables ++ "\n")
879-
mapM_ emitSumtypeCaseTagDefinition (zip [0 ..] rest)
879+
mapM_ emitSumtypeCaseTagDefinition (zip [0 ..] pointerfix)
880880
emitSumtypeCase :: Int -> XObj -> State EmitterState ()
881881
emitSumtypeCase ind (XObj (Lst [XObj (Sym (SymPath [] caseName) _) _ _, XObj (Arr []) _ _]) _ _) =
882882
appendToSrc (addIndent ind ++ "// " ++ caseName ++ "\n")

src/Sumtypes.hs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ moduleForSumtype innerEnv typeEnv env pathStrings typeName typeVariables rest i
5656
candidate = TypeCandidate {typename = typeName, variables = typeVariables, restriction = AllowOnlyNamesInScope, typemembers = rest, interfaceConstraints = [], candidateTypeEnv = typeEnv, candidateEnv = env}
5757
in do
5858
let structTy = StructTy (ConcreteNameTy (SymPath pathStrings typeName)) typeVariables
59+
ptrFix = map (recursiveMembersToPointers structTy) rest
5960
okRecursive candidate
60-
cases <- toCases typeEnv env candidate
61+
cases <- toCases typeEnv env (candidate {typemembers = ptrFix})
6162
okIniters <- initers insidePath structTy cases
6263
okTag <- binderForTag insidePath structTy
6364
(okStr, okStrDeps) <- binderForStrOrPrn typeEnv env insidePath structTy cases "str"
@@ -91,19 +92,21 @@ binderForCaseInit _ _ _ = error "binderforcaseinit"
9192

9293
concreteCaseInit :: AllocationMode -> [String] -> Ty -> SumtypeCase -> (String, Binder)
9394
concreteCaseInit allocationMode insidePath structTy sumtypeCase =
94-
instanceBinder (SymPath insidePath (caseName sumtypeCase)) (FuncTy (caseTys sumtypeCase) structTy StaticLifetimeTy) template doc
95+
instanceBinder (SymPath insidePath (caseName sumtypeCase)) (FuncTy (map removeRec (caseTys sumtypeCase)) structTy StaticLifetimeTy) template doc
9596
where
9697
doc = "creates a `" ++ caseName sumtypeCase ++ "`."
9798
template =
9899
Template
99-
(FuncTy (caseTys sumtypeCase) (VarTy "p") StaticLifetimeTy)
100+
(FuncTy (map removeRec (caseTys sumtypeCase)) (VarTy "p") StaticLifetimeTy)
100101
( \(FuncTy _ concreteStructTy _) ->
101102
let mappings = unifySignatures structTy concreteStructTy
102-
correctedTys = map (replaceTyVars mappings) (caseTys sumtypeCase)
103+
correctedTys = map (replaceTyVars mappings) (map removeRec (caseTys sumtypeCase))
103104
in (toTemplate $ "$p $NAME(" ++ joinWithComma (zipWith (curry memberArg) anonMemberNames (remove isUnit correctedTys)) ++ ")")
104105
)
105106
(const (tokensForCaseInit allocationMode structTy sumtypeCase))
106107
(\FuncTy {} -> [])
108+
removeRec (RecTy t) = t
109+
removeRec t = t
107110

108111
genericCaseInit :: AllocationMode -> [String] -> Ty -> SumtypeCase -> (String, Binder)
109112
genericCaseInit allocationMode pathStrings originalStructTy sumtypeCase =
@@ -141,13 +144,15 @@ tokensForCaseInit allocationMode sumTy@(StructTy (ConcreteNameTy _) _) sumtypeCa
141144
StackAlloc -> " $p instance;"
142145
HeapAlloc -> " $p instance = CARP_MALLOC(sizeof(" ++ show sumTy ++ "));",
143146
joinLines $ caseMemberAssignment allocationMode correctedName . fst <$> unitless,
147+
joinLines $ recCaseMemberAssignment allocationMode correctedName sumTy . fst <$> recursive,
144148
" instance._tag = " ++ tagName sumTy correctedName ++ ";",
145149
" return instance;",
146150
"}"
147151
]
148152
where
149153
correctedName = caseName sumtypeCase
150-
unitless = zip anonMemberNames $ remove isUnit (caseTys sumtypeCase)
154+
unitless = remove (isRecType . snd) $ zip anonMemberNames $ remove isUnit (caseTys sumtypeCase)
155+
recursive = filter (isRecType . snd) $ zip anonMemberNames (caseTys sumtypeCase)
151156
tokensForCaseInit _ _ _ = error "tokensforcaseinit"
152157

153158
caseMemberAssignment :: AllocationMode -> String -> String -> String
@@ -158,6 +163,15 @@ caseMemberAssignment allocationMode caseNm memberName =
158163
StackAlloc -> ".u."
159164
HeapAlloc -> "->u."
160165

166+
recCaseMemberAssignment :: AllocationMode -> String -> Ty -> String -> String
167+
recCaseMemberAssignment allocationMode caseNm sumTy memberName =
168+
" instance" ++ sep ++ caseNm ++ "." ++ memberName ++ " = CARP_MALLOC(sizeof(" ++ show sumTy ++ "));\n"
169+
++ " *instance" ++ sep ++ caseNm ++ "." ++ memberName ++ " = " ++ memberName ++ ";"
170+
where
171+
sep = case allocationMode of
172+
StackAlloc -> ".u."
173+
HeapAlloc -> "->u."
174+
161175
binderForTag :: [String] -> Ty -> Either TypeError (String, Binder)
162176
binderForTag insidePath originalStructTy@(StructTy (ConcreteNameTy _) _) =
163177
Right $ instanceBinder path (FuncTy [RefTy originalStructTy (VarTy "q")] IntTy StaticLifetimeTy) template doc

0 commit comments

Comments
 (0)