@@ -9,19 +9,19 @@ import hkmc2.semantics.*
99import hkmc2 .syntax .Tree
1010
1111class StackSafeTransform (depthLimit : Int )(using State ):
12- private val STACK_LIMIT_IDENT : Tree .Ident = Tree .Ident (" __stackLimit " )
13- private val STACK_DEPTH_IDENT : Tree .Ident = Tree .Ident (" __stackDepth " )
14- private val STACK_OFFSET_IDENT : Tree .Ident = Tree .Ident (" __stackOffset " )
15- private val STACK_HANDLER_IDENT : Tree .Ident = Tree .Ident (" __stackHandler " )
16-
17- private val predefPath : Path = State .globalThisSymbol .asPath.selN( Tree . Ident ( " Predef " ))
18- private val checkDepthPath : Path = predefPath .selN(Tree .Ident (" checkDepth" ))
19- private val resetDepthPath : Path = predefPath .selN(Tree .Ident (" resetDepth" ))
20- private val stackDelayClsPath : Path = predefPath .selN(Tree .Ident (" __StackDelay " ))
21- private val stackLimitPath : Path = predefPath .selN(STACK_LIMIT_IDENT )
22- private val stackDepthPath : Path = predefPath .selN(STACK_DEPTH_IDENT )
23- private val stackOffsetPath : Path = predefPath .selN(STACK_OFFSET_IDENT )
24- private val stackHandlerPath : Path = predefPath .selN(STACK_HANDLER_IDENT )
12+ private val STACK_LIMIT_IDENT : Tree .Ident = Tree .Ident (" stackLimit " )
13+ private val STACK_DEPTH_IDENT : Tree .Ident = Tree .Ident (" stackDepth " )
14+ private val STACK_OFFSET_IDENT : Tree .Ident = Tree .Ident (" stackOffset " )
15+ private val STACK_HANDLER_IDENT : Tree .Ident = Tree .Ident (" stackHandler " )
16+
17+ private val runtimePath : Path = State .runtimeSymbol .asPath
18+ private val checkDepthPath : Path = runtimePath .selN(Tree .Ident (" checkDepth" ))
19+ private val resetDepthPath : Path = runtimePath .selN(Tree .Ident (" resetDepth" ))
20+ private val stackDelayClsPath : Path = runtimePath .selN(Tree .Ident (" StackDelay " ))
21+ private val stackLimitPath : Path = runtimePath .selN(STACK_LIMIT_IDENT )
22+ private val stackDepthPath : Path = runtimePath .selN(STACK_DEPTH_IDENT )
23+ private val stackOffsetPath : Path = runtimePath .selN(STACK_OFFSET_IDENT )
24+ private val stackHandlerPath : Path = runtimePath .selN(STACK_HANDLER_IDENT )
2525
2626 private def intLit (n : BigInt ) = Value .Lit (Tree .IntLit (n))
2727
@@ -33,22 +33,20 @@ class StackSafeTransform(depthLimit: Int)(using State):
3333 def extractRes (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ) =
3434 if isTailCall then
3535 blockBuilder
36- .assignFieldN(predefPath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
36+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
3737 .ret(res)
3838 else
3939 val tmp = sym getOrElse TempSymbol (None , " tmp" )
4040 val offsetGtDepth = TempSymbol (None , " offsetGtDepth" )
4141 blockBuilder
42- .assignFieldN(predefPath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
42+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
4343 .assign(tmp, res)
4444 .assign(tmp, Call (resetDepthPath, tmp.asPath.asArg :: curDepth.asPath.asArg :: Nil )(true , false ))
4545 .rest(f(tmp.asPath))
46-
47- def extractResTopLevel ( res : Result , isTailCall : Bool , f : Result => Block , sym : Option [ Symbol ], curDepth : => Symbol ) =
46+
47+ def wrapStackSafe ( body : Block , resSym : Local , rest : Block ) =
4848 val resumeSym = VarSymbol (Tree .Ident (" resume" ))
4949 val handlerSym = TempSymbol (None , " stackHandler" )
50- val resSym = sym getOrElse TempSymbol (None , " res" )
51- val handlerRes = TempSymbol (None , " res" )
5250
5351 val clsSym = ClassSymbol (
5452 Tree .TypeDef (syntax.Cls , Tree .Error (), N , N ),
@@ -64,26 +62,28 @@ class StackSafeTransform(depthLimit: Int)(using State):
6462 /*
6563 fun perform() =
6664 stackOffset = stackDepth
67- let ret = resume()
68- ret
65+ resume()
6966 */
7067 blockBuilder
71- .assignFieldN(predefPath, STACK_OFFSET_IDENT , stackDepthPath)
72- .assign(handlerRes, Call (Value .Ref (resumeSym), Nil )(true , true ))
73- .ret(handlerRes.asPath)
68+ .assignFieldN(runtimePath, STACK_OFFSET_IDENT , stackDepthPath)
69+ .ret(Call (Value .Ref (resumeSym), Nil )(true , true ))
7470 ) :: Nil ,
7571 blockBuilder
76- .assignFieldN(predefPath , STACK_LIMIT_IDENT , intLit(depthLimit)) // set stackLimit before call
77- .assignFieldN(predefPath , STACK_OFFSET_IDENT , intLit(0 )) // set stackOffset = 0 before call
78- .assignFieldN(predefPath , STACK_DEPTH_IDENT , intLit(1 )) // set stackDepth = 1 before call
79- .assignFieldN(predefPath , STACK_HANDLER_IDENT , handlerSym.asPath) // assign stack handler
80- .rest(HandleBlockReturn (res) ),
72+ .assignFieldN(runtimePath , STACK_LIMIT_IDENT , intLit(depthLimit)) // set stackLimit before call
73+ .assignFieldN(runtimePath , STACK_OFFSET_IDENT , intLit(0 )) // set stackOffset = 0 before call
74+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , intLit(1 )) // set stackDepth = 1 before call
75+ .assignFieldN(runtimePath , STACK_HANDLER_IDENT , handlerSym.asPath) // assign stack handler
76+ .rest(body ),
8177 blockBuilder // reset the stack safety values
82- .assignFieldN(predefPath , STACK_DEPTH_IDENT , intLit(0 )) // set stackDepth = 0 after call
83- .assignFieldN(predefPath , STACK_HANDLER_IDENT , Value .Lit (Tree .UnitLit (true ))) // set stackHandler = null
84- .rest(f(resSym.asPath) )
78+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , intLit(0 )) // set stackDepth = 0 after call
79+ .assignFieldN(runtimePath , STACK_HANDLER_IDENT , Value .Lit (Tree .UnitLit (true ))) // set stackHandler = null
80+ .rest(rest )
8581 )
8682
83+ def extractResTopLevel (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ) =
84+ val resSym = sym getOrElse TempSymbol (None , " res" )
85+ wrapStackSafe(HandleBlockReturn (res), resSym, f(resSym.asPath))
86+
8787 // Rewrites anything that can contain a Call to increase the stack depth
8888 def transform (b : Block , curDepth : => Symbol , isTopLevel : Bool = false ): Block =
8989 def usesStack (r : Result ) = r match
@@ -119,8 +119,21 @@ class StackSafeTransform(depthLimit: Int)(using State):
119119 val hdr2 = hdr.mapConserve(applyHandler)
120120 val bod2 = rewriteBlk(bod)
121121 val rst2 = applyBlock(rst)
122- HandleBlock (l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
122+ if isTopLevel then
123+ val newRes = TempSymbol (N , " res" )
124+ val newHandler = HandleBlock (l2, newRes, par2, args2, cls2, hdr2, bod2, HandleBlockReturn (newRes.asPath))
125+ wrapStackSafe(newHandler, res2, rst2)
126+ else
127+ HandleBlock (l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
128+
123129 case _ => super .applyBlock(b)
130+
131+ override def applyHandler (hdr : Handler ): Handler =
132+ val sym2 = hdr.sym.subst
133+ val resumeSym2 = hdr.resumeSym.subst
134+ val params2 = hdr.params.mapConserve(applyParamList)
135+ val body2 = rewriteBlk(hdr.body)
136+ Handler (sym2, resumeSym2, params2, body2)
124137
125138 override def applyResult2 (r : Result )(k : Result => Block ): Block =
126139 if usesStack(r) then
0 commit comments