Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 100 additions & 41 deletions src/librustc_mir/transform/acs_propagate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
//! into one, but I can’t seem to get it just right yet, so we do the composing and decomposing
//! manually here.

use self::AcsLattice::*;

use rustc_data_structures::fnv::FnvHashMap;
use rustc::mir::repr::*;
use rustc::mir::visit::{MutVisitor, LvalueContext};
Expand All @@ -40,25 +42,49 @@ use pretty;

#[derive(PartialEq, Debug, Eq, Clone)]
enum Either<'tcx> {
Top,
Lvalue(Lvalue<'tcx>),
Const(Constant<'tcx>),
}

impl<'tcx> Lattice for Either<'tcx> {
fn bottom() -> Self { unimplemented!() }
#[derive(Debug, Clone)]
enum AcsLattice<'tcx> {
Bottom,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bottom element in this lattice is an empty HashMap. It is fine to have it this way, because HashMap::new() does not allocate, to my knowledge.

Similarly, I’m confused as to why the Top was removed from Either. I do not see how you could merge two different constants together and produce anything other than a Top.

/me shrugs

Copy link
Author

@gereeter gereeter Jun 6, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the lattice so that HashMap::new() is actually the top element of the lattice - I treat empty elements as top instead of bottom. You'll notice that join now does intersection instead of union. This is primarily because when we encounter an unknown function call, we need to mark everything as Top, because the function call could change anything. This situation can be improved with alias analysis and information about what the function might do, but the default needs to be Top. Similarly, when we enter the function, everything needs to be Top. Consider the following case:

fn foo(mut x: u32) {
    if /* random unoptimizable condition */ {
        x = 5;
    }
    println!("{}", x);
}

Previously, since the value stored for x was bottom, we would merge the fact bottom and the fact x = 5 at the end of the if statement, concluding that x was always 5. We would optimize to:

fn foo(mut x: u32) {
    if /* random unoptimizable condition */ {
        x = 5;
    }
    println!("{}", 5);
}

which is just wrong.

Note that, in terms of inspiration, while the original paper used a union-based lattice for constant propogation, GHC actually uses an intersection based lattice (see here).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I realized after I wrote that code that WBottom existed and that it would have been better to use that. However, I'm still planning to remove the requirement that lattices have bottom, again inspired by https://ghc.haskell.org/trac/ghc/wiki/Hoopl/Cleanup, and so I didn't bother to switch to a cleaner version.

Wrap(FnvHashMap<Lvalue<'tcx>, Either<'tcx>>)
}

impl<'tcx> Lattice for AcsLattice<'tcx> {
fn bottom() -> Self { Bottom }
fn join(&mut self, other: &Self) -> bool {
if self == other {
false
} else {
*self = Either::Top;
true
let other_map = match *other {
Bottom => return false,
Wrap(ref map) => map
};
let self_map = match *self {
Bottom => {
*self = Wrap(other_map.clone());
return true;
},
Wrap(ref mut map) => map
};

let mut changed = false;

for (k, v) in other_map {
let should_remove = if let Some(cur_v) = self_map.get(k) {
cur_v != v
} else {
false
};
if should_remove {
self_map.remove(k);
changed = true;
}
}

changed
}
}

type AcsLattice<'a> = FnvHashMap<Lvalue<'a>, Either<'a>>;

pub struct AcsPropagate;

impl Pass for AcsPropagate {}
Expand All @@ -79,33 +105,46 @@ impl<'tcx> MirPass<'tcx> for AcsPropagate {

struct AcsPropagateTransfer;

fn base_lvalue<'a, 'tcx>(mut lval: &'a Lvalue<'tcx>) -> &'a Lvalue<'tcx> {
while let &Lvalue::Projection(ref proj) = lval {
lval = &proj.base;
}
lval
}

impl<'tcx> Transfer<'tcx> for AcsPropagateTransfer {
type Lattice = AcsLattice<'tcx>;

fn stmt(&self, s: &Statement<'tcx>, mut lat: AcsLattice<'tcx>) -> AcsLattice<'tcx> {
fn stmt(&self, s: &Statement<'tcx>, lat: AcsLattice<'tcx>) -> AcsLattice<'tcx> {
let mut lat_map = match lat {
Bottom => FnvHashMap::default(),
Wrap(map) => map
};

let StatementKind::Assign(ref lval, ref rval) = s.kind;
if let &Lvalue::Projection(_) = lval {
let mut base = lval;
while let &Lvalue::Projection(ref proj) = base {
base = &proj.base;
}
lat.insert(base.clone(), Either::Top);
return lat;
lat_map.remove(base_lvalue(lval));
return Wrap(lat_map);
}

match *rval {
Rvalue::Use(Operand::Consume(ref nlval)) =>
lat.insert(lval.clone(), Either::Lvalue(nlval.clone())),
lat_map.insert(lval.clone(), Either::Lvalue(nlval.clone())),
Rvalue::Use(Operand::Constant(ref c)) =>
lat.insert(lval.clone(), Either::Const(c.clone())),
_ => lat.insert(lval.clone(), Either::Top)
lat_map.insert(lval.clone(), Either::Const(c.clone())),
_ => lat_map.remove(lval)
};
lat
Wrap(lat_map)
}

fn term(&self, t: &Terminator<'tcx>, mut lat: AcsLattice<'tcx>) -> Vec<AcsLattice<'tcx>> {
if let TerminatorKind::Call { destination: Some((ref dest, _)), .. } = t.kind {
lat.insert(dest.clone(), Either::Top);
match t.kind {
TerminatorKind::Call { .. } |
TerminatorKind::Drop { .. } => {
// FIXME: Be smarter here by using an alias analysis
lat = Wrap(FnvHashMap::default());
},
_ => { }
}

// FIXME: this should inspect the terminators and set their known values to constants. Esp.
Expand All @@ -122,22 +161,32 @@ struct AliasRewrite;
impl<'tcx> Rewrite<'tcx, AcsLattice<'tcx>> for AliasRewrite {
fn stmt(&self, s: &Statement<'tcx>, l: &AcsLattice<'tcx>, _: &mut CFG<'tcx>)
-> StatementChange<'tcx> {
let mut ns = s.clone();
let mut vis = RewriteAliasVisitor(&l, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 { StatementChange::Statement(ns) } else { StatementChange::None }
if let Wrap(ref map) = *l {
let mut ns = s.clone();
let mut vis = RewriteAliasVisitor(map, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 {
return StatementChange::Statement(ns);
}
}
StatementChange::None
}

fn term(&self, t: &Terminator<'tcx>, l: &AcsLattice<'tcx>, _: &mut CFG<'tcx>)
-> TerminatorChange<'tcx> {
let mut nt = t.clone();
let mut vis = RewriteAliasVisitor(&l, false);
vis.visit_terminator(START_BLOCK, &mut nt);
if vis.1 { TerminatorChange::Terminator(nt) } else { TerminatorChange::None }
if let Wrap(ref map) = *l {
let mut nt = t.clone();
let mut vis = RewriteAliasVisitor(map, false);
vis.visit_terminator(START_BLOCK, &mut nt);
if vis.1 {
return TerminatorChange::Terminator(nt);
}
}
TerminatorChange::None
}
}

struct RewriteAliasVisitor<'a, 'tcx: 'a>(&'a AcsLattice<'tcx>, bool);
struct RewriteAliasVisitor<'a, 'tcx: 'a>(&'a FnvHashMap<Lvalue<'tcx>, Either<'tcx>>, bool);
impl<'a, 'tcx> MutVisitor<'tcx> for RewriteAliasVisitor<'a, 'tcx> {
fn visit_lvalue(&mut self, lvalue: &mut Lvalue<'tcx>, context: LvalueContext) {
match context {
Expand All @@ -158,22 +207,32 @@ struct ConstRewrite;
impl<'tcx> Rewrite<'tcx, AcsLattice<'tcx>> for ConstRewrite {
fn stmt(&self, s: &Statement<'tcx>, l: &AcsLattice<'tcx>, _: &mut CFG<'tcx>)
-> StatementChange<'tcx> {
let mut ns = s.clone();
let mut vis = RewriteConstVisitor(&l, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 { StatementChange::Statement(ns) } else { StatementChange::None }
if let Wrap(ref map) = *l {
let mut ns = s.clone();
let mut vis = RewriteConstVisitor(map, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 {
return StatementChange::Statement(ns);
}
}
StatementChange::None
}

fn term(&self, t: &Terminator<'tcx>, l: &AcsLattice<'tcx>, _: &mut CFG<'tcx>)
-> TerminatorChange<'tcx> {
let mut nt = t.clone();
let mut vis = RewriteConstVisitor(&l, false);
vis.visit_terminator(START_BLOCK, &mut nt);
if vis.1 { TerminatorChange::Terminator(nt) } else { TerminatorChange::None }
if let Wrap(ref map) = *l {
let mut nt = t.clone();
let mut vis = RewriteConstVisitor(map, false);
vis.visit_terminator(START_BLOCK, &mut nt);
if vis.1 {
return TerminatorChange::Terminator(nt);
}
}
TerminatorChange::None
}
}

struct RewriteConstVisitor<'a, 'tcx: 'a>(&'a AcsLattice<'tcx>, bool);
struct RewriteConstVisitor<'a, 'tcx: 'a>(&'a FnvHashMap<Lvalue<'tcx>, Either<'tcx>>, bool);
impl<'a, 'tcx> MutVisitor<'tcx> for RewriteConstVisitor<'a, 'tcx> {
fn visit_operand(&mut self, op: &mut Operand<'tcx>) {
// To satisy borrow checker, modify `op` after inspecting it
Expand Down