Skip to content

Commit d924885

Browse files
authored
ZJIT: Create more ergonomic type profiling API (ruby#13339)
1 parent eead831 commit d924885

File tree

1 file changed

+81
-36
lines changed

1 file changed

+81
-36
lines changed

zjit/src/hir.rs

Lines changed: 81 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
use crate::{
77
cruby::*,
88
options::{get_option, DumpHIR},
9-
profile::{self, get_or_create_iseq_payload},
9+
profile::{get_or_create_iseq_payload, IseqPayload},
1010
state::ZJITState,
1111
cast::IntoUsize,
1212
};
@@ -673,6 +673,7 @@ pub struct Function {
673673
insn_types: Vec<Type>,
674674
blocks: Vec<Block>,
675675
entry_block: BlockId,
676+
profiles: Option<ProfileOracle>,
676677
}
677678

678679
impl Function {
@@ -685,6 +686,7 @@ impl Function {
685686
blocks: vec![Block::default()],
686687
entry_block: BlockId(0),
687688
param_types: vec![],
689+
profiles: None,
688690
}
689691
}
690692

@@ -994,6 +996,20 @@ impl Function {
994996
}
995997
}
996998

999+
/// Return the interpreter-profiled type of the HIR instruction at the given ISEQ instruction
1000+
/// index, if it is known. This historical type record is not a guarantee and must be checked
1001+
/// with a GuardType or similar.
1002+
fn profiled_type_of_at(&self, insn: InsnId, iseq_insn_idx: usize) -> Option<Type> {
1003+
let Some(ref profiles) = self.profiles else { return None };
1004+
let Some(entries) = profiles.types.get(&iseq_insn_idx) else { return None };
1005+
for &(entry_insn, entry_type) in entries {
1006+
if self.union_find.borrow().find_const(entry_insn) == self.union_find.borrow().find_const(insn) {
1007+
return Some(entry_type);
1008+
}
1009+
}
1010+
None
1011+
}
1012+
9971013
fn likely_is_fixnum(&self, val: InsnId, profiled_type: Type) -> bool {
9981014
return self.is_a(val, types::Fixnum) || profiled_type.is_subtype(types::Fixnum);
9991015
}
@@ -1003,20 +1019,16 @@ impl Function {
10031019
return self.push_insn(block, Insn::GuardType { val, guard_type: types::Fixnum, state });
10041020
}
10051021

1006-
fn arguments_likely_fixnums(&mut self, payload: &profile:: IseqPayload, left: InsnId, right: InsnId, state: InsnId) -> bool {
1007-
let mut left_profiled_type = types::BasicObject;
1008-
let mut right_profiled_type = types::BasicObject;
1022+
fn arguments_likely_fixnums(&mut self, left: InsnId, right: InsnId, state: InsnId) -> bool {
10091023
let frame_state = self.frame_state(state);
1010-
let insn_idx = frame_state.insn_idx;
1011-
if let Some([left_type, right_type]) = payload.get_operand_types(insn_idx as usize) {
1012-
left_profiled_type = *left_type;
1013-
right_profiled_type = *right_type;
1014-
}
1024+
let iseq_insn_idx = frame_state.insn_idx as usize;
1025+
let left_profiled_type = self.profiled_type_of_at(left, iseq_insn_idx).unwrap_or(types::BasicObject);
1026+
let right_profiled_type = self.profiled_type_of_at(right, iseq_insn_idx).unwrap_or(types::BasicObject);
10151027
self.likely_is_fixnum(left, left_profiled_type) && self.likely_is_fixnum(right, right_profiled_type)
10161028
}
10171029

1018-
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, payload: &profile::IseqPayload, state: InsnId) {
1019-
if self.arguments_likely_fixnums(payload, left, right, state) {
1030+
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, state: InsnId) {
1031+
if self.arguments_likely_fixnums(left, right, state) {
10201032
if bop == BOP_NEQ {
10211033
// For opt_neq, the interpreter checks that both neq and eq are unchanged.
10221034
self.push_insn(block, Insn::PatchPoint(Invariant::BOPRedefined { klass: INTEGER_REDEFINED_OP_FLAG, bop: BOP_EQ }));
@@ -1026,6 +1038,7 @@ impl Function {
10261038
let right = self.coerce_to_fixnum(block, right, state);
10271039
let result = self.push_insn(block, f(left, right));
10281040
self.make_equal_to(orig_insn_id, result);
1041+
self.insn_types[result.0] = self.infer_type(result);
10291042
} else {
10301043
self.push_insn_id(block, orig_insn_id);
10311044
}
@@ -1034,43 +1047,42 @@ impl Function {
10341047
/// Rewrite SendWithoutBlock opcodes into SendWithoutBlockDirect opcodes if we know the target
10351048
/// ISEQ statically. This removes run-time method lookups and opens the door for inlining.
10361049
fn optimize_direct_sends(&mut self) {
1037-
let payload = get_or_create_iseq_payload(self.iseq);
10381050
for block in self.rpo() {
10391051
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
10401052
assert!(self.blocks[block.0].insns.is_empty());
10411053
for insn_id in old_insns {
10421054
match self.find(insn_id) {
10431055
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "+" && args.len() == 1 =>
1044-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], payload, state),
1056+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], state),
10451057
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "-" && args.len() == 1 =>
1046-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], payload, state),
1058+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], state),
10471059
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "*" && args.len() == 1 =>
1048-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], payload, state),
1060+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], state),
10491061
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "/" && args.len() == 1 =>
1050-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], payload, state),
1062+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], state),
10511063
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "%" && args.len() == 1 =>
1052-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], payload, state),
1064+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], state),
10531065
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "==" && args.len() == 1 =>
1054-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], payload, state),
1066+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], state),
10551067
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "!=" && args.len() == 1 =>
1056-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], payload, state),
1068+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], state),
10571069
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<" && args.len() == 1 =>
1058-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], payload, state),
1070+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], state),
10591071
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<=" && args.len() == 1 =>
1060-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], payload, state),
1072+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], state),
10611073
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">" && args.len() == 1 =>
1062-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], payload, state),
1074+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], state),
10631075
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">=" && args.len() == 1 =>
1064-
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], payload, state),
1076+
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], state),
10651077
Insn::SendWithoutBlock { mut self_val, call_info, cd, args, state } => {
10661078
let frame_state = self.frame_state(state);
10671079
let (klass, guard_equal_to) = if let Some(klass) = self.type_of(self_val).runtime_exact_ruby_class() {
10681080
// If we know the class statically, use it to fold the lookup at compile-time.
10691081
(klass, None)
10701082
} else {
10711083
// If we know that self is top-self from profile information, guard and use it to fold the lookup at compile-time.
1072-
match payload.get_operand_types(frame_state.insn_idx) {
1073-
Some([self_type, ..]) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
1084+
match self.profiled_type_of_at(self_val, frame_state.insn_idx) {
1085+
Some(self_type) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
10741086
_ => { self.push_insn_id(block, insn_id); continue; }
10751087
}
10761088
};
@@ -1130,7 +1142,6 @@ impl Function {
11301142
fn reduce_to_ccall(
11311143
fun: &mut Function,
11321144
block: BlockId,
1133-
payload: &profile::IseqPayload,
11341145
self_type: Type,
11351146
send: Insn,
11361147
send_insn_id: InsnId,
@@ -1142,7 +1153,6 @@ impl Function {
11421153
let call_info = unsafe { (*cd).ci };
11431154
let argc = unsafe { vm_ci_argc(call_info) };
11441155
let method_id = unsafe { rb_vm_ci_mid(call_info) };
1145-
let iseq_insn_idx = fun.frame_state(state).insn_idx;
11461156

11471157
// If we have info about the class of the receiver
11481158
//
@@ -1152,10 +1162,10 @@ impl Function {
11521162
let (recv_class, guard_type) = if let Some(klass) = self_type.runtime_exact_ruby_class() {
11531163
(klass, None)
11541164
} else {
1155-
payload.get_operand_types(iseq_insn_idx)
1156-
.and_then(|types| types.get(argc as usize))
1157-
.and_then(|recv_type| recv_type.exact_ruby_class().and_then(|class| Some((class, Some(recv_type.unspecialized())))))
1158-
.ok_or(())?
1165+
let iseq_insn_idx = fun.frame_state(state).insn_idx;
1166+
let Some(recv_type) = fun.profiled_type_of_at(self_val, iseq_insn_idx) else { return Err(()) };
1167+
let Some(recv_class) = recv_type.exact_ruby_class() else { return Err(()) };
1168+
(recv_class, Some(recv_type.unspecialized()))
11591169
};
11601170

11611171
// Do method lookup
@@ -1221,14 +1231,13 @@ impl Function {
12211231
Err(())
12221232
}
12231233

1224-
let payload = get_or_create_iseq_payload(self.iseq);
12251234
for block in self.rpo() {
12261235
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
12271236
assert!(self.blocks[block.0].insns.is_empty());
12281237
for insn_id in old_insns {
12291238
if let send @ Insn::SendWithoutBlock { self_val, .. } = self.find(insn_id) {
12301239
let self_type = self.type_of(self_val);
1231-
if reduce_to_ccall(self, block, payload, self_type, send, insn_id).is_ok() {
1240+
if reduce_to_ccall(self, block, self_type, send, insn_id).is_ok() {
12321241
continue;
12331242
}
12341243
}
@@ -1598,7 +1607,7 @@ impl FrameState {
15981607
}
15991608

16001609
/// Get a stack operand at idx
1601-
fn stack_topn(&mut self, idx: usize) -> Result<InsnId, ParseError> {
1610+
fn stack_topn(&self, idx: usize) -> Result<InsnId, ParseError> {
16021611
let idx = self.stack.len() - idx - 1;
16031612
self.stack.get(idx).ok_or_else(|| ParseError::StackUnderflow(self.clone())).copied()
16041613
}
@@ -1717,8 +1726,42 @@ fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> {
17171726
Ok(())
17181727
}
17191728

1729+
/// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful
1730+
/// or correct to query from inside the optimizer. Instead, ProfileOracle provides an API to look
1731+
/// up profiled type information by HIR InsnId at a given ISEQ instruction.
1732+
#[derive(Debug)]
1733+
struct ProfileOracle {
1734+
payload: &'static IseqPayload,
1735+
/// types is a map from ISEQ instruction indices -> profiled type information at that ISEQ
1736+
/// instruction index. At a given ISEQ instruction, the interpreter has profiled the stack
1737+
/// operands to a given ISEQ instruction, and this list of pairs of (InsnId, Type) map that
1738+
/// profiling information into HIR instructions.
1739+
types: HashMap<usize, Vec<(InsnId, Type)>>,
1740+
}
1741+
1742+
impl ProfileOracle {
1743+
fn new(payload: &'static IseqPayload) -> Self {
1744+
Self { payload, types: Default::default() }
1745+
}
1746+
1747+
/// Map the interpreter-recorded types of the stack onto the HIR operands on our compile-time virtual stack
1748+
fn profile_stack(&mut self, state: &FrameState) {
1749+
let iseq_insn_idx = state.insn_idx;
1750+
let Some(operand_types) = self.payload.get_operand_types(iseq_insn_idx) else { return };
1751+
let entry = self.types.entry(iseq_insn_idx).or_insert_with(|| vec![]);
1752+
// operand_types is always going to be <= stack size (otherwise it would have an underflow
1753+
// at run-time) so use that to drive iteration.
1754+
for (idx, &insn_type) in operand_types.iter().rev().enumerate() {
1755+
let insn = state.stack_topn(idx).expect("Unexpected stack underflow in profiling");
1756+
entry.push((insn, insn_type))
1757+
}
1758+
}
1759+
}
1760+
17201761
/// Compile ISEQ into High-level IR
17211762
pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
1763+
let payload = get_or_create_iseq_payload(iseq);
1764+
let mut profiles = ProfileOracle::new(payload);
17221765
let mut fun = Function::new(iseq);
17231766
// Compute a map of PC->Block by finding jump targets
17241767
let jump_targets = compute_jump_targets(iseq);
@@ -1791,6 +1834,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
17911834
let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx) };
17921835
state.pc = pc;
17931836
let exit_state = state.clone();
1837+
profiles.profile_stack(&exit_state);
17941838

17951839
// try_into() call below is unfortunate. Maybe pick i32 instead of usize for opcodes.
17961840
let opcode: u32 = unsafe { rb_iseq_opcode_at_pc(iseq, pc) }
@@ -2061,6 +2105,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
20612105
None => {},
20622106
}
20632107

2108+
fun.profiles = Some(profiles);
20642109
Ok(fun)
20652110
}
20662111

@@ -3058,8 +3103,8 @@ mod opt_tests {
30583103
bb0():
30593104
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
30603105
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
3061-
v15:Fixnum[6] = Const Value(6)
3062-
Return v15
3106+
v14:Fixnum[6] = Const Value(6)
3107+
Return v14
30633108
"#]]);
30643109
}
30653110

0 commit comments

Comments
 (0)