Skip to content

Commit 6dd08cb

Browse files
authored
Rollup merge of rust-lang#147533 - cjgillot:coro-late-renumber, r=davidtwco
Renumber return local after state transform The current implementation of `StateTransform` renames `_0` before analyzing liveness. This is inconsistent, as a `return` terminator hardcodes a read of `_0`. This PR proposes to perform such rename *after* analyzing the body, in fact after the whole transform. The implementation is not much more complicated.
2 parents 2a10082 + 6d800ae commit 6dd08cb

File tree

4 files changed

+138
-100
lines changed

4 files changed

+138
-100
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 96 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
6868
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
6969
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
7070
use rustc_index::{Idx, IndexVec};
71-
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
71+
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
7272
use rustc_middle::mir::*;
7373
use rustc_middle::ty::util::Discr;
7474
use rustc_middle::ty::{
@@ -110,6 +110,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
110110
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
111111
if *local == self.from {
112112
*local = self.to;
113+
} else if *local == self.to {
114+
*local = self.from;
113115
}
114116
}
115117

@@ -159,13 +161,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
159161
}
160162
}
161163

164+
#[tracing::instrument(level = "trace", skip(tcx))]
162165
fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
163166
place.local = new_base.local;
164167

165168
let mut new_projection = new_base.projection.to_vec();
166169
new_projection.append(&mut place.projection.to_vec());
167170

168171
place.projection = tcx.mk_place_elems(&new_projection);
172+
tracing::trace!(?place);
169173
}
170174

171175
const SELF_ARG: Local = Local::from_u32(1);
@@ -204,8 +208,8 @@ struct TransformVisitor<'tcx> {
204208
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
205209
always_live_locals: DenseBitSet<Local>,
206210

207-
// The original RETURN_PLACE local
208-
old_ret_local: Local,
211+
// New local we just create to hold the `CoroutineState` value.
212+
new_ret_local: Local,
209213

210214
old_yield_ty: Ty<'tcx>,
211215

@@ -270,6 +274,7 @@ impl<'tcx> TransformVisitor<'tcx> {
270274
// `core::ops::CoroutineState` only has single element tuple variants,
271275
// so we can just write to the downcasted first field and then set the
272276
// discriminant to the appropriate variant.
277+
#[tracing::instrument(level = "trace", skip(self, statements))]
273278
fn make_state(
274279
&self,
275280
val: Operand<'tcx>,
@@ -341,13 +346,15 @@ impl<'tcx> TransformVisitor<'tcx> {
341346
}
342347
};
343348

349+
// Assign to `new_ret_local`, which will be replaced by `RETURN_PLACE` later.
344350
statements.push(Statement::new(
345351
source_info,
346-
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
352+
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
347353
));
348354
}
349355

350356
// Create a Place referencing a coroutine struct field
357+
#[tracing::instrument(level = "trace", skip(self), ret)]
351358
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
352359
let self_place = Place::from(SELF_ARG);
353360
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -358,6 +365,7 @@ impl<'tcx> TransformVisitor<'tcx> {
358365
}
359366

360367
// Create a statement which changes the discriminant
368+
#[tracing::instrument(level = "trace", skip(self))]
361369
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
362370
let self_place = Place::from(SELF_ARG);
363371
Statement::new(
@@ -370,6 +378,7 @@ impl<'tcx> TransformVisitor<'tcx> {
370378
}
371379

372380
// Create a statement which reads the discriminant into a temporary
381+
#[tracing::instrument(level = "trace", skip(self, body))]
373382
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
374383
let temp_decl = LocalDecl::new(self.discr_ty, body.span);
375384
let local_decls_len = body.local_decls.push(temp_decl);
@@ -382,55 +391,83 @@ impl<'tcx> TransformVisitor<'tcx> {
382391
);
383392
(assign, temp)
384393
}
394+
395+
/// Swaps all references of `old_local` and `new_local`.
396+
#[tracing::instrument(level = "trace", skip(self, body))]
397+
fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
398+
body.local_decls.swap(old_local, new_local);
399+
400+
let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
401+
visitor.visit_body(body);
402+
for suspension in &mut self.suspension_points {
403+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
404+
let location = Location { block: START_BLOCK, statement_index: 0 };
405+
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
406+
}
407+
}
385408
}
386409

387410
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
388411
fn tcx(&self) -> TyCtxt<'tcx> {
389412
self.tcx
390413
}
391414

392-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
415+
#[tracing::instrument(level = "trace", skip(self), ret)]
416+
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
393417
assert!(!self.remap.contains(*local));
394418
}
395419

396-
fn visit_place(
397-
&mut self,
398-
place: &mut Place<'tcx>,
399-
_context: PlaceContext,
400-
_location: Location,
401-
) {
420+
#[tracing::instrument(level = "trace", skip(self), ret)]
421+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
402422
// Replace an Local in the remap with a coroutine struct access
403423
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
404424
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
405425
}
406426
}
407427

408-
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
428+
#[tracing::instrument(level = "trace", skip(self, stmt), ret)]
429+
fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
409430
// Remove StorageLive and StorageDead statements for remapped locals
410-
for s in &mut data.statements {
411-
if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = s.kind
412-
&& self.remap.contains(l)
413-
{
414-
s.make_nop(true);
415-
}
431+
if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
432+
&& self.remap.contains(l)
433+
{
434+
stmt.make_nop(true);
416435
}
436+
self.super_statement(stmt, location);
437+
}
417438

418-
let ret_val = match data.terminator().kind {
439+
#[tracing::instrument(level = "trace", skip(self, term), ret)]
440+
fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
441+
if let TerminatorKind::Return = term.kind {
442+
// `visit_basic_block_data` introduces `Return` terminators which read `RETURN_PLACE`.
443+
// But this `RETURN_PLACE` is already remapped, so we should not touch it again.
444+
return;
445+
}
446+
self.super_terminator(term, location);
447+
}
448+
449+
#[tracing::instrument(level = "trace", skip(self, data), ret)]
450+
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
451+
match data.terminator().kind {
419452
TerminatorKind::Return => {
420-
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
421-
}
422-
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
423-
Some((false, Some((resume, resume_arg)), value.clone(), drop))
453+
let source_info = data.terminator().source_info;
454+
// We must assign the value first in case it gets declared dead below
455+
self.make_state(
456+
Operand::Move(Place::return_place()),
457+
source_info,
458+
true,
459+
&mut data.statements,
460+
);
461+
// Return state.
462+
let state = VariantIdx::new(CoroutineArgs::RETURNED);
463+
data.statements.push(self.set_discr(state, source_info));
464+
data.terminator_mut().kind = TerminatorKind::Return;
424465
}
425-
_ => None,
426-
};
427-
428-
if let Some((is_return, resume, v, drop)) = ret_val {
429-
let source_info = data.terminator().source_info;
430-
// We must assign the value first in case it gets declared dead below
431-
self.make_state(v, source_info, is_return, &mut data.statements);
432-
let state = if let Some((resume, mut resume_arg)) = resume {
433-
// Yield
466+
TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
467+
let source_info = data.terminator().source_info;
468+
// We must assign the value first in case it gets declared dead below
469+
self.make_state(value.clone(), source_info, false, &mut data.statements);
470+
// Yield state.
434471
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
435472

436473
// The resume arg target location might itself be remapped if its base local is
@@ -461,13 +498,11 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
461498
storage_liveness,
462499
});
463500

464-
VariantIdx::new(state)
465-
} else {
466-
// Return
467-
VariantIdx::new(CoroutineArgs::RETURNED) // state for returned
468-
};
469-
data.statements.push(self.set_discr(state, source_info));
470-
data.terminator_mut().kind = TerminatorKind::Return;
501+
let state = VariantIdx::new(state);
502+
data.statements.push(self.set_discr(state, source_info));
503+
data.terminator_mut().kind = TerminatorKind::Return;
504+
}
505+
_ => {}
471506
}
472507

473508
self.super_basic_block_data(block, data);
@@ -483,6 +518,7 @@ fn make_aggregate_adt<'tcx>(
483518
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
484519
}
485520

521+
#[tracing::instrument(level = "trace", skip(tcx, body))]
486522
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
487523
let coroutine_ty = body.local_decls.raw[1].ty;
488524

@@ -495,6 +531,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
495531
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
496532
}
497533

534+
#[tracing::instrument(level = "trace", skip(tcx, body))]
498535
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
499536
let ref_coroutine_ty = body.local_decls.raw[1].ty;
500537

@@ -511,27 +548,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
511548
.visit_body(body);
512549
}
513550

514-
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
515-
///
516-
/// `local` will be changed to a new local decl with type `ty`.
517-
///
518-
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
519-
/// valid value to it before its first use.
520-
fn replace_local<'tcx>(
521-
local: Local,
522-
ty: Ty<'tcx>,
523-
body: &mut Body<'tcx>,
524-
tcx: TyCtxt<'tcx>,
525-
) -> Local {
526-
let new_decl = LocalDecl::new(ty, body.span);
527-
let new_local = body.local_decls.push(new_decl);
528-
body.local_decls.swap(local, new_local);
529-
530-
RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
531-
532-
new_local
533-
}
534-
535551
/// Transforms the `body` of the coroutine applying the following transforms:
536552
///
537553
/// - Eliminates all the `get_context` calls that async lowering created.
@@ -553,6 +569,7 @@ fn replace_local<'tcx>(
553569
/// The async lowering step and the type / lifetime inference / checking are
554570
/// still using the `ResumeTy` indirection for the time being, and that indirection
555571
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
572+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
556573
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
557574
let context_mut_ref = Ty::new_task_context(tcx);
558575

@@ -606,6 +623,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
606623
}
607624

608625
#[cfg_attr(not(debug_assertions), allow(unused))]
626+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
609627
fn replace_resume_ty_local<'tcx>(
610628
tcx: TyCtxt<'tcx>,
611629
body: &mut Body<'tcx>,
@@ -670,6 +688,7 @@ struct LivenessInfo {
670688
/// case none exist, the local is considered to be always live.
671689
/// - a local has to be stored if it is either directly used after the
672690
/// the suspend point, or if it is live and has been previously borrowed.
691+
#[tracing::instrument(level = "trace", skip(tcx, body))]
673692
fn locals_live_across_suspend_points<'tcx>(
674693
tcx: TyCtxt<'tcx>,
675694
body: &Body<'tcx>,
@@ -945,6 +964,7 @@ impl StorageConflictVisitor<'_, '_> {
945964
}
946965
}
947966

967+
#[tracing::instrument(level = "trace", skip(liveness, body))]
948968
fn compute_layout<'tcx>(
949969
liveness: LivenessInfo,
950970
body: &Body<'tcx>,
@@ -1049,7 +1069,9 @@ fn compute_layout<'tcx>(
10491069
variant_source_info,
10501070
storage_conflicts,
10511071
};
1072+
debug!(?remap);
10521073
debug!(?layout);
1074+
debug!(?storage_liveness);
10531075

10541076
(remap, layout, storage_liveness)
10551077
}
@@ -1221,6 +1243,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
12211243
}
12221244
}
12231245

1246+
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
12241247
fn create_coroutine_resume_function<'tcx>(
12251248
tcx: TyCtxt<'tcx>,
12261249
transform: TransformVisitor<'tcx>,
@@ -1299,7 +1322,7 @@ fn create_coroutine_resume_function<'tcx>(
12991322
}
13001323

13011324
/// An operation that can be performed on a coroutine.
1302-
#[derive(PartialEq, Copy, Clone)]
1325+
#[derive(PartialEq, Copy, Clone, Debug)]
13031326
enum Operation {
13041327
Resume,
13051328
Drop,
@@ -1314,6 +1337,7 @@ impl Operation {
13141337
}
13151338
}
13161339

1340+
#[tracing::instrument(level = "trace", skip(transform, body))]
13171341
fn create_cases<'tcx>(
13181342
body: &mut Body<'tcx>,
13191343
transform: &TransformVisitor<'tcx>,
@@ -1445,6 +1469,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14451469
// This only applies to coroutines
14461470
return;
14471471
};
1472+
tracing::trace!(def_id = ?body.source.def_id());
1473+
14481474
let old_ret_ty = body.return_ty();
14491475

14501476
assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
@@ -1491,10 +1517,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14911517
}
14921518
};
14931519

1494-
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1495-
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1496-
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
1497-
14981520
// We need to insert clean drop for unresumed state and perform drop elaboration
14991521
// (finally in open_drop_for_tuple) before async drop expansion.
15001522
// Async drops, produced by this drop elaboration, will be expanded,
@@ -1541,6 +1563,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15411563

15421564
let can_return = can_return(tcx, body, body.typing_env(tcx));
15431565

1566+
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1567+
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1568+
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1569+
tracing::trace!(?new_ret_local);
1570+
15441571
// Run the transformation which converts Places from Local to coroutine struct
15451572
// accesses for locals in `remap`.
15461573
// It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
@@ -1553,13 +1580,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15531580
storage_liveness,
15541581
always_live_locals,
15551582
suspension_points: Vec::new(),
1556-
old_ret_local,
15571583
discr_ty,
1584+
new_ret_local,
15581585
old_ret_ty,
15591586
old_yield_ty,
15601587
};
15611588
transform.visit_body(body);
15621589

1590+
// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
1591+
transform.replace_local(RETURN_PLACE, new_ret_local, body);
1592+
15631593
// MIR parameters are not explicitly assigned-to when entering the MIR body.
15641594
// If we want to save their values inside the coroutine state, we need to do so explicitly.
15651595
let source_info = SourceInfo::outermost(body.span);

0 commit comments

Comments
 (0)