Skip to content

Commit abcc628

Browse files
k0kubuntekknolagi
andauthored
ZJIT: Compile invokesuper with dynamic dispatch (ruby#14444)
Co-authored-by: Max Bernstein <[email protected]>
1 parent 1f76b09 commit abcc628

File tree

3 files changed

+156
-15
lines changed

3 files changed

+156
-15
lines changed

test/ruby/test_zjit.rb

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_nested_local_access
250250
}, call_threshold: 3, insns: [:getlocal, :setlocal, :getlocal_WC_0, :setlocal_WC_1]
251251
end
252252

253-
def test_read_local_written_by_children_iseqs
253+
def test_send_with_local_written_by_blockiseq
254254
assert_compiles '[1, 2]', %q{
255255
def test
256256
l1 = nil
@@ -343,6 +343,46 @@ def test(a, b = 2) = [a, b]
343343
}
344344
end
345345

346+
def test_invokesuper
347+
assert_compiles '[6, 60]', %q{
348+
class Foo
349+
def foo(a) = a + 1
350+
def bar(a) = a + 10
351+
end
352+
353+
class Bar < Foo
354+
def foo(a) = super(a) + 2
355+
def bar(a) = super + 20
356+
end
357+
358+
bar = Bar.new
359+
[bar.foo(3), bar.bar(30)]
360+
}
361+
end
362+
363+
def test_invokesuper_with_local_written_by_blockiseq
364+
# Using `assert_runs` because we don't compile invokeblock yet
365+
assert_runs '3', %q{
366+
class Foo
367+
def test
368+
yield
369+
end
370+
end
371+
372+
class Bar < Foo
373+
def test
374+
a = 1
375+
super do
376+
a += 2
377+
end
378+
a
379+
end
380+
end
381+
382+
Bar.new.test
383+
}
384+
end
385+
346386
def test_invokebuiltin
347387
omit 'Test fails at the moment due to not handling optional parameters'
348388
assert_compiles '["."]', %q{

zjit/src/codegen.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
367367
Insn::SendWithoutBlockDirect { cd, state, args, .. } if args.len() + 1 > C_ARG_OPNDS.len() => // +1 for self
368368
gen_send_without_block(jit, asm, *cd, &function.frame_state(*state)),
369369
Insn::SendWithoutBlockDirect { cme, iseq, self_val, args, state, .. } => gen_send_without_block_direct(cb, jit, asm, *cme, *iseq, opnd!(self_val), opnds!(args), &function.frame_state(*state)),
370+
&Insn::InvokeSuper { cd, blockiseq, state, .. } => gen_invokesuper(jit, asm, cd, blockiseq, &function.frame_state(state)),
370371
// Ensure we have enough room fit ec, self, and arguments
371372
// TODO remove this check when we have stack args (we can use Time.new to test it)
372373
Insn::InvokeBuiltin { bf, state, .. } if bf.argc + 2 > (C_ARG_OPNDS.len() as i32) => return Err(*state),
@@ -1076,6 +1077,38 @@ fn gen_send_without_block_direct(
10761077
ret
10771078
}
10781079

1080+
/// Compile a dynamic dispatch for `super`
1081+
fn gen_invokesuper(
1082+
jit: &mut JITState,
1083+
asm: &mut Assembler,
1084+
cd: *const rb_call_data,
1085+
blockiseq: IseqPtr,
1086+
state: &FrameState,
1087+
) -> lir::Opnd {
1088+
gen_incr_counter(asm, Counter::dynamic_send_count);
1089+
1090+
// Save PC and SP
1091+
gen_prepare_call_with_gc(asm, state);
1092+
gen_save_sp(asm, state.stack().len());
1093+
1094+
// Spill locals and stack
1095+
gen_spill_locals(jit, asm, state);
1096+
gen_spill_stack(jit, asm, state);
1097+
1098+
asm_comment!(asm, "call super with dynamic dispatch");
1099+
unsafe extern "C" {
1100+
fn rb_vm_invokesuper(ec: EcPtr, cfp: CfpPtr, cd: VALUE, blockiseq: IseqPtr) -> VALUE;
1101+
}
1102+
let ret = asm.ccall(
1103+
rb_vm_invokesuper as *const u8,
1104+
vec![EC, CFP, (cd as usize).into(), VALUE(blockiseq as usize).into()],
1105+
);
1106+
// TODO: Add a PatchPoint here that can side-exit the function if the callee messed with
1107+
// the frame's locals
1108+
1109+
ret
1110+
}
1111+
10791112
/// Compile a string resurrection
10801113
fn gen_string_copy(asm: &mut Assembler, recv: Opnd, chilled: bool, state: &FrameState) -> Opnd {
10811114
// TODO: split rb_ec_str_resurrect into separate functions

zjit/src/hir.rs

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -565,10 +565,13 @@ pub enum Insn {
565565
/// `name` is for printing purposes only
566566
CCall { cfun: *const u8, args: Vec<InsnId>, name: ID, return_type: Type, elidable: bool },
567567

568-
/// Send without block with dynamic dispatch
568+
/// Un-optimized fallback implementation (dynamic dispatch) for send-ish instructions
569569
/// Ignoring keyword arguments etc for now
570570
SendWithoutBlock { self_val: InsnId, cd: *const rb_call_data, args: Vec<InsnId>, state: InsnId },
571571
Send { self_val: InsnId, cd: *const rb_call_data, blockiseq: IseqPtr, args: Vec<InsnId>, state: InsnId },
572+
InvokeSuper { self_val: InsnId, cd: *const rb_call_data, blockiseq: IseqPtr, args: Vec<InsnId>, state: InsnId },
573+
574+
/// Optimized ISEQ call
572575
SendWithoutBlockDirect {
573576
self_val: InsnId,
574577
cd: *const rb_call_data,
@@ -817,6 +820,13 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
817820
}
818821
Ok(())
819822
}
823+
Insn::InvokeSuper { self_val, blockiseq, args, .. } => {
824+
write!(f, "InvokeSuper {self_val}, {:p}", self.ptr_map.map_ptr(blockiseq))?;
825+
for arg in args {
826+
write!(f, ", {arg}")?;
827+
}
828+
Ok(())
829+
}
820830
Insn::InvokeBuiltin { bf, args, .. } => {
821831
write!(f, "InvokeBuiltin {}", unsafe { CStr::from_ptr(bf.name) }.to_str().unwrap())?;
822832
for arg in args {
@@ -1310,6 +1320,13 @@ impl Function {
13101320
args: find_vec!(args),
13111321
state,
13121322
},
1323+
&InvokeSuper { self_val, cd, blockiseq, ref args, state } => InvokeSuper {
1324+
self_val: find!(self_val),
1325+
cd,
1326+
blockiseq,
1327+
args: find_vec!(args),
1328+
state,
1329+
},
13131330
&InvokeBuiltin { bf, ref args, state, return_type } => InvokeBuiltin { bf, args: find_vec!(args), state, return_type },
13141331
&ArrayDup { val, state } => ArrayDup { val: find!(val), state },
13151332
&HashDup { val, state } => HashDup { val: find!(val), state },
@@ -1418,6 +1435,7 @@ impl Function {
14181435
Insn::SendWithoutBlock { .. } => types::BasicObject,
14191436
Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
14201437
Insn::Send { .. } => types::BasicObject,
1438+
Insn::InvokeSuper { .. } => types::BasicObject,
14211439
Insn::InvokeBuiltin { return_type, .. } => return_type.unwrap_or(types::BasicObject),
14221440
Insn::Defined { pushval, .. } => Type::from_value(*pushval).union(types::NilClass),
14231441
Insn::DefinedIvar { .. } => types::BasicObject,
@@ -2206,7 +2224,8 @@ impl Function {
22062224
}
22072225
&Insn::Send { self_val, ref args, state, .. }
22082226
| &Insn::SendWithoutBlock { self_val, ref args, state, .. }
2209-
| &Insn::SendWithoutBlockDirect { self_val, ref args, state, .. } => {
2227+
| &Insn::SendWithoutBlockDirect { self_val, ref args, state, .. }
2228+
| &Insn::InvokeSuper { self_val, ref args, state, .. } => {
22102229
worklist.push_back(self_val);
22112230
worklist.extend(args);
22122231
worklist.push_back(state);
@@ -2830,14 +2849,14 @@ fn insn_idx_at_offset(idx: u32, offset: i64) -> u32 {
28302849

28312850
struct BytecodeInfo {
28322851
jump_targets: Vec<u32>,
2833-
has_send: bool,
2852+
has_blockiseq: bool,
28342853
}
28352854

28362855
fn compute_bytecode_info(iseq: *const rb_iseq_t) -> BytecodeInfo {
28372856
let iseq_size = unsafe { get_iseq_encoded_size(iseq) };
28382857
let mut insn_idx = 0;
28392858
let mut jump_targets = HashSet::new();
2840-
let mut has_send = false;
2859+
let mut has_blockiseq = false;
28412860
while insn_idx < iseq_size {
28422861
// Get the current pc and opcode
28432862
let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx) };
@@ -2861,13 +2880,18 @@ fn compute_bytecode_info(iseq: *const rb_iseq_t) -> BytecodeInfo {
28612880
jump_targets.insert(insn_idx);
28622881
}
28632882
}
2864-
YARVINSN_send => has_send = true,
2883+
YARVINSN_send | YARVINSN_invokesuper => {
2884+
let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
2885+
if !blockiseq.is_null() {
2886+
has_blockiseq = true;
2887+
}
2888+
}
28652889
_ => {}
28662890
}
28672891
}
28682892
let mut result = jump_targets.into_iter().collect::<Vec<_>>();
28692893
result.sort();
2870-
BytecodeInfo { jump_targets: result, has_send }
2894+
BytecodeInfo { jump_targets: result, has_blockiseq }
28712895
}
28722896

28732897
#[derive(Debug, PartialEq, Clone, Copy)]
@@ -2984,7 +3008,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
29843008
let mut profiles = ProfileOracle::new(payload);
29853009
let mut fun = Function::new(iseq);
29863010
// Compute a map of PC->Block by finding jump targets
2987-
let BytecodeInfo { jump_targets, has_send } = compute_bytecode_info(iseq);
3011+
let BytecodeInfo { jump_targets, has_blockiseq } = compute_bytecode_info(iseq);
29883012
let mut insn_idx_to_block = HashMap::new();
29893013
for insn_idx in jump_targets {
29903014
if insn_idx == 0 {
@@ -3321,7 +3345,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
33213345
}
33223346
YARVINSN_getlocal_WC_0 => {
33233347
let ep_offset = get_arg(pc, 0).as_u32();
3324-
if iseq_type == ISEQ_TYPE_EVAL || has_send {
3348+
if iseq_type == ISEQ_TYPE_EVAL || has_blockiseq {
33253349
// On eval, the locals are always on the heap, so read the local using EP.
33263350
let val = fun.push_insn(block, Insn::GetLocal { ep_offset, level: 0 });
33273351
state.setlocal(ep_offset, val);
@@ -3341,7 +3365,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
33413365
let ep_offset = get_arg(pc, 0).as_u32();
33423366
let val = state.stack_pop()?;
33433367
state.setlocal(ep_offset, val);
3344-
if iseq_type == ISEQ_TYPE_EVAL || has_send {
3368+
if iseq_type == ISEQ_TYPE_EVAL || has_blockiseq {
33453369
// On eval, the locals are always on the heap, so write the local using EP.
33463370
fun.push_insn(block, Insn::SetLocal { val, ep_offset, level: 0 });
33473371
}
@@ -3521,6 +3545,34 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
35213545
state.setlocal(ep_offset, val);
35223546
}
35233547
}
3548+
YARVINSN_invokesuper => {
3549+
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
3550+
let call_info = unsafe { rb_get_call_data_ci(cd) };
3551+
if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) } & !VM_CALL_SUPER & !VM_CALL_ZSUPER) {
3552+
// Unknown call type; side-exit into the interpreter
3553+
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
3554+
fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
3555+
break; // End the block
3556+
}
3557+
let argc = unsafe { vm_ci_argc((*cd).ci) };
3558+
let args = state.stack_pop_n(argc as usize)?;
3559+
let recv = state.stack_pop()?;
3560+
let blockiseq: IseqPtr = get_arg(pc, 1).as_ptr();
3561+
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
3562+
let result = fun.push_insn(block, Insn::InvokeSuper { self_val: recv, cd, blockiseq, args, state: exit_id });
3563+
state.stack_push(result);
3564+
3565+
if !blockiseq.is_null() {
3566+
// Reload locals that may have been modified by the blockiseq.
3567+
// TODO: Avoid reloading locals that are not referenced by the blockiseq
3568+
// or not used after this. Max thinks we could eventually DCE them.
3569+
for local_idx in 0..state.locals.len() {
3570+
let ep_offset = local_idx_to_ep_offset(iseq, local_idx) as u32;
3571+
let val = fun.push_insn(block, Insn::GetLocal { ep_offset, level: 0 });
3572+
state.setlocal(ep_offset, val);
3573+
}
3574+
}
3575+
}
35243576
YARVINSN_getglobal => {
35253577
let id = ID(get_arg(pc, 0).as_u64());
35263578
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
@@ -4971,7 +5023,6 @@ mod tests {
49715023
assert_snapshot!(hir_string("test"), @r"
49725024
fn test@<compiled>:2:
49735025
bb0(v0:BasicObject, v1:BasicObject):
4974-
v5:BasicObject = GetLocal l0, EP@3
49755026
SideExit UnhandledCallType(BlockArg)
49765027
");
49775028
}
@@ -5004,26 +5055,43 @@ mod tests {
50045055
// TODO(max): Figure out how to generate a call with TAILCALL flag
50055056

50065057
#[test]
5007-
fn test_cant_compile_super() {
5058+
fn test_compile_super() {
50085059
eval("
50095060
def test = super()
50105061
");
50115062
assert_snapshot!(hir_string("test"), @r"
50125063
fn test@<compiled>:2:
50135064
bb0(v0:BasicObject):
5014-
SideExit UnhandledYARVInsn(invokesuper)
5065+
v5:BasicObject = InvokeSuper v0, 0x1000
5066+
CheckInterrupts
5067+
Return v5
50155068
");
50165069
}
50175070

50185071
#[test]
5019-
fn test_cant_compile_zsuper() {
5072+
fn test_compile_zsuper() {
50205073
eval("
50215074
def test = super
50225075
");
50235076
assert_snapshot!(hir_string("test"), @r"
50245077
fn test@<compiled>:2:
50255078
bb0(v0:BasicObject):
5026-
SideExit UnhandledYARVInsn(invokesuper)
5079+
v5:BasicObject = InvokeSuper v0, 0x1000
5080+
CheckInterrupts
5081+
Return v5
5082+
");
5083+
}
5084+
5085+
#[test]
5086+
fn test_cant_compile_super_nil_blockarg() {
5087+
eval("
5088+
def test = super(&nil)
5089+
");
5090+
assert_snapshot!(hir_string("test"), @r"
5091+
fn test@<compiled>:2:
5092+
bb0(v0:BasicObject):
5093+
v4:NilClass = Const Value(nil)
5094+
SideExit UnhandledCallType(BlockArg)
50275095
");
50285096
}
50295097

0 commit comments

Comments
 (0)