@@ -69,6 +69,39 @@ structure Context where
6969 /-- All preceding instructions that need to be `reshape`d. This is used to make `visitFnBody`
7070 tail-recursive. -/
7171 instrs : Array FnBody := #[]
72+ rename : Std.TreeMap VarId Arg fun a b => compare a.idx b.idx := {}
73+
74+ def Context.renameVar (ctx : Context) (a : VarId) : Arg :=
75+ ctx.rename.getD a (.var a)
76+
77+ def Context.renameVar! (ctx : Context) (a : VarId) : VarId :=
78+ match ctx.rename[a]? with
79+ | none => a
80+ | some .erased => panic! "malformed IR"
81+ | some (.var v) => v
82+
83+ def Context.renameArg (ctx : Context) (a : Arg) : Arg :=
84+ match a with
85+ | .erased => .erased
86+ | .var v => ctx.rename.getD v a
87+
88+ def Context.renameArgs (ctx : Context) (a : Array Arg) : Array Arg :=
89+ a.map ctx.renameArg
90+
91+ def Context.insertRename (ctx : Context) (v : VarId) (a : Arg) : Context :=
92+ { ctx with rename := ctx.rename.insert v a }
93+
94+ def Context.insert (var : VarId) (e : Entry) (ctx : Context) : Context :=
95+ { ctx with vars := ctx.vars.insert var e }
96+
97+ def Context.insertIfNew (var : VarId) (e : Entry) (ctx : Context) : Context :=
98+ { ctx with vars := ctx.vars.insertIfNew var e }
99+
100+ def Context.remove (ctx : Context) (v : VarId) : Context :=
101+ { ctx with vars := ctx.vars.erase v }
102+
103+ def Context.addInstr (ctx : Context) (b : FnBody) : Context :=
104+ { ctx with instrs := ctx.instrs.push b }
72105
73106abbrev M := StateM Nat
74107
@@ -111,21 +144,21 @@ def Entry.ofType (ty : IRType) (rc : Int) : Entry :=
111144 | _ => .persistent
112145
113146def Context.addVar (ctx : Context) (x : VarId) (ty : IRType) (rc : Int) : Context :=
114- { ctx with vars := ctx.vars. insertIfNew x (.ofType ty rc) }
147+ ctx. insertIfNew x (.ofType ty rc)
115148
116149/-- Add a variable entry if we need to (i.e. if `ty` is a struct type). -/
117150def Context.addVarBasic (ctx : Context) (x : VarId) (ty : IRType) : Context :=
118151 match ty with
119152 | ty@(.struct _ tys _ _) =>
120153 if tys.any needsRC then
121- { ctx with vars := ctx.vars. insert x (.ofStruct 0 ty 0 ) }
154+ ctx. insert x (.ofStruct 0 ty 0 )
122155 else
123- { ctx with vars := ctx.vars. insert x .persistent }
156+ ctx. insert x .persistent
124157 | .union _ tys =>
125158 if tys.any needsRC then
126- { ctx with vars := ctx.vars. insert x (.unknownUnion tys 0 ) }
159+ ctx. insert x (.unknownUnion tys 0 )
127160 else
128- { ctx with vars := ctx.vars. insert x .persistent }
161+ ctx. insert x .persistent
129162 | _ => ctx
130163
131164/--
@@ -136,16 +169,15 @@ def Context.addProjInfo (ctx : Context) (proj : VarId) (t : IRType) (idx : Nat)
136169 (parent : VarId) (parentEntry : Entry) : Context := Id.run do
137170 let : Inhabited Context := ⟨ctx⟩
138171 let .struct cidx tys fields := parentEntry | unreachable!
139- let .unbound rc := fields[idx]! |
140- -- TODO: this happens occasionally when it feels like cse should've done something
141- let ctx := ctx.addVarBasic proj t
142- return { ctx with vars := ctx.vars.insert parent parentEntry }
143- let ctx := ctx.addVar proj t rc
144- let vars := ctx.vars.insert parent (.struct cidx tys (fields.set! idx (.var proj)))
145- { ctx with vars }
146-
147- def Context.addInstr (ctx : Context) (b : FnBody) : Context :=
148- { ctx with instrs := ctx.instrs.push b }
172+ match fields[idx]! with
173+ | .erased =>
174+ return (ctx.insert parent parentEntry).insertRename proj .erased
175+ | .var v =>
176+ return (ctx.insert parent parentEntry).insertRename proj (.var v)
177+ | .unbound rc =>
178+ let ctx := ctx.addVar proj t rc
179+ let ctx := ctx.insert parent (.struct cidx tys (fields.set! idx (.var proj)))
180+ ctx.addInstr (.vdecl proj t (.proj cidx idx parent) .nil)
149181
150182def Context.emitRCDiff (v : VarId) (check : Bool) (rc : Int) (ctx : Context) : Context :=
151183 if rc > 0 then
@@ -160,16 +192,14 @@ partial def Context.accumulateRCDiff (v : VarId) (check : Bool) (rc : Int) (ctx
160192 match ctx.vars[v]? with
161193 | none => ctx.emitRCDiff v check rc
162194 | some .persistent => ctx
163- | some (.ref check' rc') =>
164- { ctx with vars := ctx.vars.insert v (.ref (check && check') (rc + rc')) }
165- | some (.unknownUnion tys rc') =>
166- { ctx with vars := ctx.vars.insert v (.unknownUnion tys (rc + rc')) }
195+ | some (.ref check' rc') => ctx.insert v (.ref (check && check') (rc + rc'))
196+ | some (.unknownUnion tys rc') => ctx.insert v (.unknownUnion tys (rc + rc'))
167197 | some (.struct cidx tys objs) =>
168198 let (objs, ctx) := Id.run <| StateT.run (s := ctx) <| objs.mapM fun
169199 | .var v' => modifyGet fun ctx => (.var v', ctx.accumulateRCDiff v' true rc)
170200 | .unbound rc' => modifyGet fun ctx => (.unbound (rc + rc'), ctx)
171201 | .erased => return .erased
172- { ctx with vars := ctx.vars. insert v (.struct cidx tys objs) }
202+ ctx. insert v (.struct cidx tys objs)
173203
174204/--
175205Performs all necessary accumulated RC increments and decrements on `v`. If `ignoreInc` is true then
@@ -182,13 +212,13 @@ partial def Context.useVar (ctx : Context) (v : VarId) (ignoreInc : Bool := fals
182212 if ignoreInc ∧ rc ≥ 0 then
183213 return ctx
184214 let ctx := ctx.emitRCDiff v (tys.any (!·.isDefiniteRef)) rc
185- let ctx := { ctx with vars := ctx.vars. insert v (.unknownUnion tys 0 ) }
215+ let ctx := ctx. insert v (.unknownUnion tys 0 )
186216 return ctx
187217 | some (.ref check rc) =>
188218 if ignoreInc ∧ rc ≥ 0 then
189219 return ctx
190220 let ctx := ctx.emitRCDiff v check rc
191- let ctx := { ctx with vars := ctx.vars. insert v (.ref check 0 ) }
221+ let ctx := ctx. insert v (.ref check 0 )
192222 return ctx
193223 | some (.struct cidx tys objs) =>
194224 let mut ctx := ctx
@@ -209,7 +239,7 @@ partial def Context.useVar (ctx : Context) (v : VarId) (ignoreInc : Bool := fals
209239 ctx := ctx.addVar var ty 0
210240 objs := objs.set i (.var var)
211241 | .erased => pure ()
212- ctx := { ctx with vars := ctx.vars. insert v (.struct cidx tys objs) }
242+ ctx := ctx. insert v (.struct cidx tys objs)
213243 return ctx
214244
215245def Context.useArg (ctx : Context) (a : Arg) : M Context :=
@@ -228,7 +258,7 @@ def Context.resetRC (ctx : Context) : Context :=
228258 | _, .unknownUnion tys _ => .unknownUnion tys 0
229259 | _, .ref c _ => .ref c 0
230260 | _, e => e
231- { vars }
261+ { vars, rename := ctx.rename }
232262
233263def Context.finish (ctx : Context) : M Context := do
234264 let mut ctx := ctx
@@ -239,12 +269,9 @@ def Context.finish (ctx : Context) : M Context := do
239269def Context.addCtorKnowledge (ctx : Context) (v : VarId) (cidx : Nat) : Context :=
240270 match ctx.vars[v]? with
241271 | some (.unknownUnion tys rc) =>
242- { ctx with vars := ctx.vars. insert v (.ofStruct cidx tys[cidx]! rc) }
272+ ctx. insert v (.ofStruct cidx tys[cidx]! rc)
243273 | _ => ctx
244274
245- def Context.remove (ctx : Context) (v : VarId) : Context :=
246- { ctx with vars := ctx.vars.erase v }
247-
248275def atConstructorIndex (ty : IRType) (i : Nat) : Array IRType :=
249276 match ty with
250277 | .struct _ tys _ _ => tys
@@ -257,34 +284,65 @@ def atConstructorIndex (ty : IRType) (i : Nat) : Array IRType :=
257284def visitExpr (x : VarId) (t : IRType) (v : Expr) (ctx : Context) : M Context := do
258285 match v with
259286 | .proj c i y =>
287+ let y := ctx.renameVar! y
288+ let v := .proj c i y
260289 match ctx.vars[y]? with
261- | none => return ctx -- just an object projection, nothing to see here
290+ | none => return ctx.addInstr (.vdecl x t v .nil) -- just an object projection, nothing to see here
262291 | some .persistent =>
263- return { ctx with vars := ctx.vars. insert x .persistent }
292+ return ( ctx. insert x .persistent).addInstr (.vdecl x t v .nil)
264293 | some (.unknownUnion tys rc) =>
265294 return ctx.addProjInfo x t i y (.ofStruct c tys[c]! rc)
266295 | some e@(.struct cidx ..) =>
267296 if c ≠ cidx then
268- return ctx.addInstr .unreachable
297+ return ( ctx.addInstr .unreachable).addInstr (.vdecl x t v .nil)
269298 else
270299 return ctx.addProjInfo x t i y e
271300 | some (.ref _ _) =>
272- return ctx.addVarBasic x t
273- | .fap _ args =>
301+ return (ctx.addVarBasic x t).addInstr (.vdecl x t v .nil)
302+ | .fap nm args =>
303+ let args := ctx.renameArgs args
304+ let v := .fap nm args
274305 if args.size = 0 then
275- return { ctx with vars := ctx.vars. insert x .persistent }
306+ return ( ctx. insert x .persistent).addInstr (.vdecl x t v .nil)
276307 else
277- args.foldlM (·.useArg) (ctx.addVarBasic x t)
278- | .ap x args =>
279- args.foldlM (·.useArg) (← ctx.useVar x)
280- | .pap _ args =>
281- args.foldlM (·.useArg) ctx
282- | .isShared x => ctx.useVar x
283- | .reset _ x => ctx.useVar x
284- | .reuse x _ _ args => args.foldlM (·.useArg) (← ctx.useVar x)
285- | .box _ y => (ctx.addVarBasic x t).useVar y
286- | .lit _ | .sproj .. | .uproj .. | .unbox _ => return ctx
308+ return (← args.foldlM (·.useArg) (ctx.addVarBasic x t)).addInstr (.vdecl x t v .nil)
309+ | .ap y args =>
310+ match ctx.renameVar y with
311+ | .erased =>
312+ return ctx.insertRename x .erased
313+ | .var y =>
314+ let args := ctx.renameArgs args
315+ let v := .ap y args
316+ return (← args.foldlM (·.useArg) (← ctx.useVar y)).addInstr (.vdecl x t v .nil)
317+ | .pap nm args =>
318+ let args := ctx.renameArgs args
319+ let v := .pap nm args
320+ return (← args.foldlM (·.useArg) ctx).addInstr (.vdecl x t v .nil)
321+ | .isShared y =>
322+ match ctx.renameVar y with
323+ | .erased => return ctx.addInstr (.vdecl x t (.lit (.num 1 )) .nil)
324+ | .var y =>
325+ let v := .isShared y
326+ return (← ctx.useVar y).addInstr (.vdecl x t v .nil)
327+ | .reset n y =>
328+ let y := ctx.renameVar! y
329+ let v := .reset n y
330+ return (← ctx.useVar y).addInstr (.vdecl x t v .nil)
331+ | .reuse y i u args =>
332+ let y := ctx.renameVar! y
333+ let v := .reuse y i u args
334+ return (← args.foldlM (·.useArg) (← ctx.useVar x)).addInstr (.vdecl x t v .nil)
335+ | .box ty y =>
336+ let y := ctx.renameVar! y
337+ let v := .box ty y
338+ return (← (ctx.addVarBasic x t).useVar y).addInstr (.vdecl x t v .nil)
339+ | .lit _ => return ctx.addInstr (.vdecl x t v .nil)
340+ | .sproj c n o y => return ctx.addInstr (.vdecl x t (.sproj c n o (ctx.renameVar! y)) .nil)
341+ | .uproj c i y => return ctx.addInstr (.vdecl x t (.uproj c i (ctx.renameVar! y)) .nil)
342+ | .unbox y => return ctx.addInstr (.vdecl x t (.unbox (ctx.renameVar! y)) .nil)
287343 | .ctor c args =>
344+ let args := ctx.renameArgs args
345+ let v := .ctor c args
288346 if t.isStruct then
289347 let tys := atConstructorIndex t c.cidx
290348 let e := .struct c.cidx tys <| Vector.ofFn fun ⟨i, _⟩ =>
@@ -295,16 +353,15 @@ def visitExpr (x : VarId) (t : IRType) (v : Expr) (ctx : Context) : M Context :=
295353 match args[i]! with
296354 | .var v => ctx.addVar v tys[i] 0
297355 | .erased => ctx
298- let ctx := { ctx with vars := ctx.vars. insert x e }
299- return ctx
356+ let ctx := ctx. insert x e
357+ return ctx.addInstr (.vdecl x t v .nil)
300358 else
301- args.foldlM (·.useArg) ctx
359+ return (← args.foldlM (·.useArg) ctx).addInstr (.vdecl x t v .nil)
302360
303361partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
304362 match body with
305363 | .vdecl x t v b =>
306364 let ctx ← visitExpr x t v ctx
307- let ctx := ctx.addInstr (.vdecl x t v .nil)
308365 visitFnBody b ctx
309366 | .jdecl j xs v b =>
310367 let v ← visitFnBody v (ctx.resetRC.addParams xs)
@@ -314,21 +371,33 @@ partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
314371 -- increment on persistent value, ignore
315372 visitFnBody b ctx
316373 else
317- visitFnBody b (ctx.accumulateRCDiff x c n)
374+ match ctx.renameVar x with
375+ | .erased => visitFnBody b ctx
376+ | .var x =>
377+ visitFnBody b (ctx.accumulateRCDiff x c n)
318378 | .dec x n c p b =>
319379 if p then
320380 -- decrement on persistent value, ignore
321381 visitFnBody b ctx
322382 else
323- let ctx := ctx.accumulateRCDiff x c (-n)
324- -- we can delay increments but we shouldn't delay deallocations
325- let ctx ← ctx.useVar x (ignoreInc := true )
326- visitFnBody b ctx
383+ match ctx.renameVar x with
384+ | .erased => visitFnBody b ctx
385+ | .var x =>
386+ let ctx := ctx.accumulateRCDiff x c (-n)
387+ -- we can delay increments but we shouldn't delay deallocations
388+ let ctx ← ctx.useVar x (ignoreInc := true )
389+ visitFnBody b ctx
327390 | .unreachable => return reshape ctx.instrs body
328- | .ret _ | .jmp _ _ =>
391+ | .ret a =>
392+ let a := ctx.renameArg a
393+ let ctx ← ctx.finish
394+ return reshape ctx.instrs (.ret a)
395+ | .jmp jp args =>
396+ let args := ctx.renameArgs args
329397 let ctx ← ctx.finish
330- return reshape ctx.instrs body
398+ return reshape ctx.instrs (.jmp jp args)
331399 | .case nm x xTy alts =>
400+ let x := ctx.renameVar! x
332401 if let some (.struct cidx _ _) := ctx.vars[x]? then
333402 -- because apparently this isn't guaranteed?!
334403 let body? := alts.findSome? fun alt =>
@@ -349,13 +418,26 @@ partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
349418 return reshape instrs body
350419 | .del v b =>
351420 visitFnBody b (ctx.remove v |>.addInstr (.del v .nil))
352- | .sset v _ _ _ _ _ b
353- | .uset v _ _ _ b
354- | .setTag v _ b =>
421+ | .sset v c i o y t b =>
422+ let v := ctx.renameVar! v
423+ let y := ctx.renameVar! y
424+ let ctx ← ctx.useVar v
425+ let ctx := ctx.addInstr (.sset v c i o y t .nil)
426+ visitFnBody b ctx
427+ | .uset v c i y b =>
428+ let v := ctx.renameVar! v
429+ let y := ctx.renameVar! y
430+ let ctx ← ctx.useVar v
431+ let ctx := ctx.addInstr (.uset v c i y .nil)
432+ visitFnBody b ctx
433+ | .setTag v i b =>
434+ let v := ctx.renameVar! v
355435 let ctx ← ctx.useVar v
356- let ctx := ctx.addInstr (body.setBody .nil)
436+ let ctx := ctx.addInstr (.setTag v i .nil)
357437 visitFnBody b ctx
358438 | .set v i a b =>
439+ let v := ctx.renameVar! v
440+ let a := ctx.renameArg a
359441 let ctx ← ctx.useVar v
360442 let ctx ← ctx.useArg a
361443 let ctx := ctx.addInstr (.set v i a .nil)
0 commit comments