Skip to content

Commit 90e757e

Browse files
committed
fix struct rc pass
1 parent b2a4183 commit 90e757e

File tree

2 files changed

+142
-60
lines changed

2 files changed

+142
-60
lines changed

src/Lean/Compiler/IR.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
7474
decls ← updateSorryDep decls
7575
logDecls `result decls
7676
checkDecls decls
77-
decls ← toposortDecls decls
7877
if (← getOptions).getBool (tracePrefixOptionName ++ `struct_rc) ||
7978
(← getOptions).getBool tracePrefixOptionName then
8079
let decls2 := decls.map StructRC.visitDecl
8180
log (LogEntry.step `struct_rc decls2)
81+
decls ← toposortDecls decls
8282
addDecls decls
8383
inferMeta decls
8484
return decls

src/Lean/Compiler/IR/StructRC.lean

Lines changed: 141 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,39 @@ structure Context where
6969
/-- All preceding instructions that need to be `reshape`d. This is used to make `visitFnBody`
7070
tail-recursive. -/
7171
instrs : Array FnBody := #[]
72+
rename : Std.TreeMap VarId Arg fun a b => compare a.idx b.idx := {}
73+
74+
def Context.renameVar (ctx : Context) (a : VarId) : Arg :=
75+
ctx.rename.getD a (.var a)
76+
77+
def Context.renameVar! (ctx : Context) (a : VarId) : VarId :=
78+
match ctx.rename[a]? with
79+
| none => a
80+
| some .erased => panic! "malformed IR"
81+
| some (.var v) => v
82+
83+
def Context.renameArg (ctx : Context) (a : Arg) : Arg :=
84+
match a with
85+
| .erased => .erased
86+
| .var v => ctx.rename.getD v a
87+
88+
def Context.renameArgs (ctx : Context) (a : Array Arg) : Array Arg :=
89+
a.map ctx.renameArg
90+
91+
def Context.insertRename (ctx : Context) (v : VarId) (a : Arg) : Context :=
92+
{ ctx with rename := ctx.rename.insert v a }
93+
94+
def Context.insert (var : VarId) (e : Entry) (ctx : Context) : Context :=
95+
{ ctx with vars := ctx.vars.insert var e }
96+
97+
def Context.insertIfNew (var : VarId) (e : Entry) (ctx : Context) : Context :=
98+
{ ctx with vars := ctx.vars.insertIfNew var e }
99+
100+
def Context.remove (ctx : Context) (v : VarId) : Context :=
101+
{ ctx with vars := ctx.vars.erase v }
102+
103+
def Context.addInstr (ctx : Context) (b : FnBody) : Context :=
104+
{ ctx with instrs := ctx.instrs.push b }
72105

73106
abbrev M := StateM Nat
74107

@@ -111,21 +144,21 @@ def Entry.ofType (ty : IRType) (rc : Int) : Entry :=
111144
| _ => .persistent
112145

113146
def Context.addVar (ctx : Context) (x : VarId) (ty : IRType) (rc : Int) : Context :=
114-
{ ctx with vars := ctx.vars.insertIfNew x (.ofType ty rc) }
147+
ctx.insertIfNew x (.ofType ty rc)
115148

116149
/-- Add a variable entry if we need to (i.e. if `ty` is a struct type). -/
117150
def Context.addVarBasic (ctx : Context) (x : VarId) (ty : IRType) : Context :=
118151
match ty with
119152
| ty@(.struct _ tys _ _) =>
120153
if tys.any needsRC then
121-
{ ctx with vars := ctx.vars.insert x (.ofStruct 0 ty 0) }
154+
ctx.insert x (.ofStruct 0 ty 0)
122155
else
123-
{ ctx with vars := ctx.vars.insert x .persistent }
156+
ctx.insert x .persistent
124157
| .union _ tys =>
125158
if tys.any needsRC then
126-
{ ctx with vars := ctx.vars.insert x (.unknownUnion tys 0) }
159+
ctx.insert x (.unknownUnion tys 0)
127160
else
128-
{ ctx with vars := ctx.vars.insert x .persistent }
161+
ctx.insert x .persistent
129162
| _ => ctx
130163

131164
/--
@@ -136,16 +169,15 @@ def Context.addProjInfo (ctx : Context) (proj : VarId) (t : IRType) (idx : Nat)
136169
(parent : VarId) (parentEntry : Entry) : Context := Id.run do
137170
let : Inhabited Context := ⟨ctx⟩
138171
let .struct cidx tys fields := parentEntry | unreachable!
139-
let .unbound rc := fields[idx]! |
140-
-- TODO: this happens occasionally when it feels like cse should've done something
141-
let ctx := ctx.addVarBasic proj t
142-
return { ctx with vars := ctx.vars.insert parent parentEntry }
143-
let ctx := ctx.addVar proj t rc
144-
let vars := ctx.vars.insert parent (.struct cidx tys (fields.set! idx (.var proj)))
145-
{ ctx with vars }
146-
147-
def Context.addInstr (ctx : Context) (b : FnBody) : Context :=
148-
{ ctx with instrs := ctx.instrs.push b }
172+
match fields[idx]! with
173+
| .erased =>
174+
return (ctx.insert parent parentEntry).insertRename proj .erased
175+
| .var v =>
176+
return (ctx.insert parent parentEntry).insertRename proj (.var v)
177+
| .unbound rc =>
178+
let ctx := ctx.addVar proj t rc
179+
let ctx := ctx.insert parent (.struct cidx tys (fields.set! idx (.var proj)))
180+
ctx.addInstr (.vdecl proj t (.proj cidx idx parent) .nil)
149181

150182
def Context.emitRCDiff (v : VarId) (check : Bool) (rc : Int) (ctx : Context) : Context :=
151183
if rc > 0 then
@@ -160,16 +192,14 @@ partial def Context.accumulateRCDiff (v : VarId) (check : Bool) (rc : Int) (ctx
160192
match ctx.vars[v]? with
161193
| none => ctx.emitRCDiff v check rc
162194
| some .persistent => ctx
163-
| some (.ref check' rc') =>
164-
{ ctx with vars := ctx.vars.insert v (.ref (check && check') (rc + rc')) }
165-
| some (.unknownUnion tys rc') =>
166-
{ ctx with vars := ctx.vars.insert v (.unknownUnion tys (rc + rc')) }
195+
| some (.ref check' rc') => ctx.insert v (.ref (check && check') (rc + rc'))
196+
| some (.unknownUnion tys rc') => ctx.insert v (.unknownUnion tys (rc + rc'))
167197
| some (.struct cidx tys objs) =>
168198
let (objs, ctx) := Id.run <| StateT.run (s := ctx) <| objs.mapM fun
169199
| .var v' => modifyGet fun ctx => (.var v', ctx.accumulateRCDiff v' true rc)
170200
| .unbound rc' => modifyGet fun ctx => (.unbound (rc + rc'), ctx)
171201
| .erased => return .erased
172-
{ ctx with vars := ctx.vars.insert v (.struct cidx tys objs) }
202+
ctx.insert v (.struct cidx tys objs)
173203

174204
/--
175205
Performs all necessary accumulated RC increments and decrements on `v`. If `ignoreInc` is true then
@@ -182,13 +212,13 @@ partial def Context.useVar (ctx : Context) (v : VarId) (ignoreInc : Bool := fals
182212
if ignoreInc ∧ rc ≥ 0 then
183213
return ctx
184214
let ctx := ctx.emitRCDiff v (tys.any (!·.isDefiniteRef)) rc
185-
let ctx := { ctx with vars := ctx.vars.insert v (.unknownUnion tys 0) }
215+
let ctx := ctx.insert v (.unknownUnion tys 0)
186216
return ctx
187217
| some (.ref check rc) =>
188218
if ignoreInc ∧ rc ≥ 0 then
189219
return ctx
190220
let ctx := ctx.emitRCDiff v check rc
191-
let ctx := { ctx with vars := ctx.vars.insert v (.ref check 0) }
221+
let ctx := ctx.insert v (.ref check 0)
192222
return ctx
193223
| some (.struct cidx tys objs) =>
194224
let mut ctx := ctx
@@ -209,7 +239,7 @@ partial def Context.useVar (ctx : Context) (v : VarId) (ignoreInc : Bool := fals
209239
ctx := ctx.addVar var ty 0
210240
objs := objs.set i (.var var)
211241
| .erased => pure ()
212-
ctx := { ctx with vars := ctx.vars.insert v (.struct cidx tys objs) }
242+
ctx := ctx.insert v (.struct cidx tys objs)
213243
return ctx
214244

215245
def Context.useArg (ctx : Context) (a : Arg) : M Context :=
@@ -228,7 +258,7 @@ def Context.resetRC (ctx : Context) : Context :=
228258
| _, .unknownUnion tys _ => .unknownUnion tys 0
229259
| _, .ref c _ => .ref c 0
230260
| _, e => e
231-
{ vars }
261+
{ vars, rename := ctx.rename }
232262

233263
def Context.finish (ctx : Context) : M Context := do
234264
let mut ctx := ctx
@@ -239,12 +269,9 @@ def Context.finish (ctx : Context) : M Context := do
239269
def Context.addCtorKnowledge (ctx : Context) (v : VarId) (cidx : Nat) : Context :=
240270
match ctx.vars[v]? with
241271
| some (.unknownUnion tys rc) =>
242-
{ ctx with vars := ctx.vars.insert v (.ofStruct cidx tys[cidx]! rc) }
272+
ctx.insert v (.ofStruct cidx tys[cidx]! rc)
243273
| _ => ctx
244274

245-
def Context.remove (ctx : Context) (v : VarId) : Context :=
246-
{ ctx with vars := ctx.vars.erase v }
247-
248275
def atConstructorIndex (ty : IRType) (i : Nat) : Array IRType :=
249276
match ty with
250277
| .struct _ tys _ _ => tys
@@ -257,34 +284,65 @@ def atConstructorIndex (ty : IRType) (i : Nat) : Array IRType :=
257284
def visitExpr (x : VarId) (t : IRType) (v : Expr) (ctx : Context) : M Context := do
258285
match v with
259286
| .proj c i y =>
287+
let y := ctx.renameVar! y
288+
let v := .proj c i y
260289
match ctx.vars[y]? with
261-
| none => return ctx -- just an object projection, nothing to see here
290+
| none => return ctx.addInstr (.vdecl x t v .nil) -- just an object projection, nothing to see here
262291
| some .persistent =>
263-
return { ctx with vars := ctx.vars.insert x .persistent }
292+
return (ctx.insert x .persistent).addInstr (.vdecl x t v .nil)
264293
| some (.unknownUnion tys rc) =>
265294
return ctx.addProjInfo x t i y (.ofStruct c tys[c]! rc)
266295
| some e@(.struct cidx ..) =>
267296
if c ≠ cidx then
268-
return ctx.addInstr .unreachable
297+
return (ctx.addInstr .unreachable).addInstr (.vdecl x t v .nil)
269298
else
270299
return ctx.addProjInfo x t i y e
271300
| some (.ref _ _) =>
272-
return ctx.addVarBasic x t
273-
| .fap _ args =>
301+
return (ctx.addVarBasic x t).addInstr (.vdecl x t v .nil)
302+
| .fap nm args =>
303+
let args := ctx.renameArgs args
304+
let v := .fap nm args
274305
if args.size = 0 then
275-
return { ctx with vars := ctx.vars.insert x .persistent }
306+
return (ctx.insert x .persistent).addInstr (.vdecl x t v .nil)
276307
else
277-
args.foldlM (·.useArg) (ctx.addVarBasic x t)
278-
| .ap x args =>
279-
args.foldlM (·.useArg) (← ctx.useVar x)
280-
| .pap _ args =>
281-
args.foldlM (·.useArg) ctx
282-
| .isShared x => ctx.useVar x
283-
| .reset _ x => ctx.useVar x
284-
| .reuse x _ _ args => args.foldlM (·.useArg) (← ctx.useVar x)
285-
| .box _ y => (ctx.addVarBasic x t).useVar y
286-
| .lit _ | .sproj .. | .uproj .. | .unbox _ => return ctx
308+
return (← args.foldlM (·.useArg) (ctx.addVarBasic x t)).addInstr (.vdecl x t v .nil)
309+
| .ap y args =>
310+
match ctx.renameVar y with
311+
| .erased =>
312+
return ctx.insertRename x .erased
313+
| .var y =>
314+
let args := ctx.renameArgs args
315+
let v := .ap y args
316+
return (← args.foldlM (·.useArg) (← ctx.useVar y)).addInstr (.vdecl x t v .nil)
317+
| .pap nm args =>
318+
let args := ctx.renameArgs args
319+
let v := .pap nm args
320+
return (← args.foldlM (·.useArg) ctx).addInstr (.vdecl x t v .nil)
321+
| .isShared y =>
322+
match ctx.renameVar y with
323+
| .erased => return ctx.addInstr (.vdecl x t (.lit (.num 1)) .nil)
324+
| .var y =>
325+
let v := .isShared y
326+
return (← ctx.useVar y).addInstr (.vdecl x t v .nil)
327+
| .reset n y =>
328+
let y := ctx.renameVar! y
329+
let v := .reset n y
330+
return (← ctx.useVar y).addInstr (.vdecl x t v .nil)
331+
| .reuse y i u args =>
332+
let y := ctx.renameVar! y
333+
let v := .reuse y i u args
334+
return (← args.foldlM (·.useArg) (← ctx.useVar x)).addInstr (.vdecl x t v .nil)
335+
| .box ty y =>
336+
let y := ctx.renameVar! y
337+
let v := .box ty y
338+
return (← (ctx.addVarBasic x t).useVar y).addInstr (.vdecl x t v .nil)
339+
| .lit _ => return ctx.addInstr (.vdecl x t v .nil)
340+
| .sproj c n o y => return ctx.addInstr (.vdecl x t (.sproj c n o (ctx.renameVar! y)) .nil)
341+
| .uproj c i y => return ctx.addInstr (.vdecl x t (.uproj c i (ctx.renameVar! y)) .nil)
342+
| .unbox y => return ctx.addInstr (.vdecl x t (.unbox (ctx.renameVar! y)) .nil)
287343
| .ctor c args =>
344+
let args := ctx.renameArgs args
345+
let v := .ctor c args
288346
if t.isStruct then
289347
let tys := atConstructorIndex t c.cidx
290348
let e := .struct c.cidx tys <| Vector.ofFn fun ⟨i, _⟩ =>
@@ -295,16 +353,15 @@ def visitExpr (x : VarId) (t : IRType) (v : Expr) (ctx : Context) : M Context :=
295353
match args[i]! with
296354
| .var v => ctx.addVar v tys[i] 0
297355
| .erased => ctx
298-
let ctx := { ctx with vars := ctx.vars.insert x e }
299-
return ctx
356+
let ctx := ctx.insert x e
357+
return ctx.addInstr (.vdecl x t v .nil)
300358
else
301-
args.foldlM (·.useArg) ctx
359+
return (← args.foldlM (·.useArg) ctx).addInstr (.vdecl x t v .nil)
302360

303361
partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
304362
match body with
305363
| .vdecl x t v b =>
306364
let ctx ← visitExpr x t v ctx
307-
let ctx := ctx.addInstr (.vdecl x t v .nil)
308365
visitFnBody b ctx
309366
| .jdecl j xs v b =>
310367
let v ← visitFnBody v (ctx.resetRC.addParams xs)
@@ -314,21 +371,33 @@ partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
314371
-- increment on persistent value, ignore
315372
visitFnBody b ctx
316373
else
317-
visitFnBody b (ctx.accumulateRCDiff x c n)
374+
match ctx.renameVar x with
375+
| .erased => visitFnBody b ctx
376+
| .var x =>
377+
visitFnBody b (ctx.accumulateRCDiff x c n)
318378
| .dec x n c p b =>
319379
if p then
320380
-- decrement on persistent value, ignore
321381
visitFnBody b ctx
322382
else
323-
let ctx := ctx.accumulateRCDiff x c (-n)
324-
-- we can delay increments but we shouldn't delay deallocations
325-
let ctx ← ctx.useVar x (ignoreInc := true)
326-
visitFnBody b ctx
383+
match ctx.renameVar x with
384+
| .erased => visitFnBody b ctx
385+
| .var x =>
386+
let ctx := ctx.accumulateRCDiff x c (-n)
387+
-- we can delay increments but we shouldn't delay deallocations
388+
let ctx ← ctx.useVar x (ignoreInc := true)
389+
visitFnBody b ctx
327390
| .unreachable => return reshape ctx.instrs body
328-
| .ret _ | .jmp _ _ =>
391+
| .ret a =>
392+
let a := ctx.renameArg a
393+
let ctx ← ctx.finish
394+
return reshape ctx.instrs (.ret a)
395+
| .jmp jp args =>
396+
let args := ctx.renameArgs args
329397
let ctx ← ctx.finish
330-
return reshape ctx.instrs body
398+
return reshape ctx.instrs (.jmp jp args)
331399
| .case nm x xTy alts =>
400+
let x := ctx.renameVar! x
332401
if let some (.struct cidx _ _) := ctx.vars[x]? then
333402
-- because apparently this isn't guaranteed?!
334403
let body? := alts.findSome? fun alt =>
@@ -349,13 +418,26 @@ partial def visitFnBody (body : FnBody) (ctx : Context) : M FnBody := do
349418
return reshape instrs body
350419
| .del v b =>
351420
visitFnBody b (ctx.remove v |>.addInstr (.del v .nil))
352-
| .sset v _ _ _ _ _ b
353-
| .uset v _ _ _ b
354-
| .setTag v _ b =>
421+
| .sset v c i o y t b =>
422+
let v := ctx.renameVar! v
423+
let y := ctx.renameVar! y
424+
let ctx ← ctx.useVar v
425+
let ctx := ctx.addInstr (.sset v c i o y t .nil)
426+
visitFnBody b ctx
427+
| .uset v c i y b =>
428+
let v := ctx.renameVar! v
429+
let y := ctx.renameVar! y
430+
let ctx ← ctx.useVar v
431+
let ctx := ctx.addInstr (.uset v c i y .nil)
432+
visitFnBody b ctx
433+
| .setTag v i b =>
434+
let v := ctx.renameVar! v
355435
let ctx ← ctx.useVar v
356-
let ctx := ctx.addInstr (body.setBody .nil)
436+
let ctx := ctx.addInstr (.setTag v i .nil)
357437
visitFnBody b ctx
358438
| .set v i a b =>
439+
let v := ctx.renameVar! v
440+
let a := ctx.renameArg a
359441
let ctx ← ctx.useVar v
360442
let ctx ← ctx.useArg a
361443
let ctx := ctx.addInstr (.set v i a .nil)

0 commit comments

Comments
 (0)