Skip to content

Commit 89849f3

Browse files
committed
ZJIT: Support JIT-to-JIT calls to callees with optional parameters
* Correct JIT entry points for optionals so each optional starts as nil before its initialization routine runs. Establish that `jit_entry_points[filled_opts_num]` gives the appropriate entry point * Correct the number of HIR block parameters for each JIT entry point * Entry points that share the same ISEQ PC get separate entries since they start with different state. No more deduplication. * Reject post parameters. This was hidden behind the check for optionals. * Make sure to visit every BB in iseq_to_hir(). Some weren't visited when the initialization routine for an optional terminated the block in a `SideExit`. Remove the now impossible `FailedOptionalArguments`.
1 parent 5d35e24 commit 89849f3

File tree

5 files changed

+287
-111
lines changed

5 files changed

+287
-111
lines changed

zjit/src/codegen.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,10 +1264,21 @@ fn gen_send_without_block_direct(
12641264

12651265
// Set up arguments
12661266
let mut c_args = vec![recv];
1267-
c_args.extend(args);
1267+
c_args.extend(&args);
1268+
1269+
let num_optionals_passed = if unsafe { get_iseq_flags_has_opt(iseq) } {
1270+
// See vm_call_iseq_setup_normal_opt_start in vm_insnhelper.c
1271+
let lead_num = unsafe { get_iseq_body_param_lead_num(iseq) } as u32;
1272+
let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) } as u32;
1273+
assert!(args.len() as u32 <= lead_num + opt_num);
1274+
let num_optionals_passed = args.len() as u32 - lead_num;
1275+
num_optionals_passed
1276+
} else {
1277+
0
1278+
};
12681279

12691280
// Make a method call. The target address will be rewritten once compiled.
1270-
let iseq_call = IseqCall::new(iseq);
1281+
let iseq_call = IseqCall::new(iseq, num_optionals_passed);
12711282
let dummy_ptr = cb.get_write_ptr().raw_ptr(cb);
12721283
jit.iseq_calls.push(iseq_call.clone());
12731284
let ret = asm.ccall_with_iseq_call(dummy_ptr, c_args, &iseq_call);
@@ -2129,7 +2140,8 @@ c_callable! {
21292140
// function_stub_hit_body() may allocate and call gc_validate_pc(), so we always set PC.
21302141
let iseq_call = unsafe { Rc::from_raw(iseq_call_ptr as *const IseqCall) };
21312142
let iseq = iseq_call.iseq.get();
2132-
let pc = unsafe { rb_iseq_pc_at_idx(iseq, 0) }; // TODO: handle opt_pc once supported
2143+
let entry_insn_idxs = crate::hir::jit_entry_insns(iseq);
2144+
let pc = unsafe { rb_iseq_pc_at_idx(iseq, entry_insn_idxs[iseq_call.jit_entry_idx.to_usize()]) };
21332145
unsafe { rb_set_cfp_pc(cfp, pc) };
21342146

21352147
// JIT-to-JIT calls don't set SP or fill nils to uninitialized (non-argument) locals.
@@ -2189,13 +2201,8 @@ fn function_stub_hit_body(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result
21892201
debug!("{err:?}: gen_iseq failed: {}", iseq_get_location(iseq_call.iseq.get(), 0));
21902202
})?;
21912203

2192-
// We currently don't support JIT-to-JIT calls for ISEQs with optional arguments.
2193-
// So we only need to use jit_entry_ptrs[0] for now. TODO(Shopify/ruby#817): Support optional arguments.
2194-
let Some(&jit_entry_ptr) = jit_entry_ptrs.first() else {
2195-
return Err(CompileError::JitToJitOptional)
2196-
};
2197-
21982204
// Update the stub to call the code pointer
2205+
let jit_entry_ptr = jit_entry_ptrs[iseq_call.jit_entry_idx.to_usize()];
21992206
let code_addr = jit_entry_ptr.raw_ptr(cb);
22002207
let iseq = iseq_call.iseq.get();
22012208
iseq_call.regenerate(cb, |asm| {
@@ -2407,7 +2414,7 @@ impl Assembler {
24072414

24082415
/// Store info about a JIT entry point
24092416
pub struct JITEntry {
2410-
/// Index that corresponds to jit_entry_insns()
2417+
/// Index that corresponds to [crate::hir::jit_entry_insns]
24112418
jit_entry_idx: usize,
24122419
/// Position where the entry point starts
24132420
start_addr: Cell<Option<CodePtr>>,
@@ -2430,6 +2437,9 @@ pub struct IseqCall {
24302437
/// Callee ISEQ that start_addr jumps to
24312438
pub iseq: Cell<IseqPtr>,
24322439

2440+
/// Index that corresponds to [crate::hir::jit_entry_insns]
2441+
jit_entry_idx: u32,
2442+
24332443
/// Position where the call instruction starts
24342444
start_addr: Cell<Option<CodePtr>>,
24352445

@@ -2441,11 +2451,12 @@ pub type IseqCallRef = Rc<IseqCall>;
24412451

24422452
impl IseqCall {
24432453
/// Allocate a new IseqCall
2444-
fn new(iseq: IseqPtr) -> IseqCallRef {
2454+
fn new(iseq: IseqPtr, jit_entry_idx: u32) -> IseqCallRef {
24452455
let iseq_call = IseqCall {
24462456
iseq: Cell::new(iseq),
24472457
start_addr: Cell::new(None),
24482458
end_addr: Cell::new(None),
2459+
jit_entry_idx,
24492460
};
24502461
Rc::new(iseq_call)
24512462
}

zjit/src/hir.rs

Lines changed: 79 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,7 +1468,7 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
14681468

14691469
use Counter::*;
14701470
if unsafe { rb_get_iseq_flags_has_rest(iseq) } { count_failure(complex_arg_pass_param_rest) }
1471-
if unsafe { rb_get_iseq_flags_has_opt(iseq) } { count_failure(complex_arg_pass_param_opt) }
1471+
if unsafe { rb_get_iseq_flags_has_post(iseq) } { count_failure(complex_arg_pass_param_post) }
14721472
if unsafe { rb_get_iseq_flags_has_kw(iseq) } { count_failure(complex_arg_pass_param_kw) }
14731473
if unsafe { rb_get_iseq_flags_has_kwrest(iseq) } { count_failure(complex_arg_pass_param_kwrest) }
14741474
if unsafe { rb_get_iseq_flags_has_block(iseq) } { count_failure(complex_arg_pass_param_block) }
@@ -1511,7 +1511,8 @@ pub struct Function {
15111511
blocks: Vec<Block>,
15121512
/// Entry block for the interpreter
15131513
entry_block: BlockId,
1514-
/// Entry block for JIT-to-JIT calls
1514+
/// Entry blocks for JIT-to-JIT calls. Length will be `opt_num+1`, for callers
1515+
/// fulfilling `(0..=opt_num)` optional parameters.
15151516
jit_entry_blocks: Vec<BlockId>,
15161517
profiles: Option<ProfileOracle>,
15171518
}
@@ -2070,9 +2071,8 @@ impl Function {
20702071
for jit_entry_block in self.jit_entry_blocks.iter() {
20712072
let entry_params = self.blocks[jit_entry_block.0].params.iter();
20722073
let param_types = self.param_types.iter();
2073-
assert_eq!(
2074-
entry_params.len(),
2075-
param_types.len(),
2074+
assert!(
2075+
param_types.len() >= entry_params.len(),
20762076
"param types should be initialized before type inference",
20772077
);
20782078
for (param, param_type) in std::iter::zip(entry_params, param_types) {
@@ -4296,7 +4296,8 @@ fn insn_idx_at_offset(idx: u32, offset: i64) -> u32 {
42964296
}
42974297

42984298
/// List of insn_idx that starts a JIT entry block
4299-
fn jit_entry_insns(iseq: IseqPtr) -> Vec<u32> {
4299+
pub fn jit_entry_insns(iseq: IseqPtr) -> Vec<u32> {
4300+
// TODO(alan): Make an iterator type for this instead of copying all of the opt_table each call
43004301
let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) };
43014302
if opt_num > 0 {
43024303
let mut result = vec![];
@@ -4306,10 +4307,6 @@ fn jit_entry_insns(iseq: IseqPtr) -> Vec<u32> {
43064307
let insn_idx = unsafe { opt_table.offset(opt_idx).read().as_u32() };
43074308
result.push(insn_idx);
43084309
}
4309-
4310-
// Deduplicate entries with HashSet since opt_table may have duplicated entries, e.g. proc { |a=a| a }
4311-
result.sort();
4312-
result.dedup();
43134310
result
43144311
} else {
43154312
vec![0]
@@ -4375,7 +4372,6 @@ pub enum ParseError {
43754372
StackUnderflow(FrameState),
43764373
MalformedIseq(u32), // insn_idx into iseq_encoded
43774374
Validation(ValidationError),
4378-
FailedOptionalArguments,
43794375
NotAllowed,
43804376
}
43814377

@@ -4456,28 +4452,45 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
44564452
// Compute a map of PC->Block by finding jump targets
44574453
let jit_entry_insns = jit_entry_insns(iseq);
44584454
let BytecodeInfo { jump_targets, has_blockiseq } = compute_bytecode_info(iseq, &jit_entry_insns);
4455+
4456+
// Make all empty basic blocks. The ordering of the BBs matters as it is taken as a schedule
4457+
// in the backend without a scheduling pass. TODO: Higher quality scheduling during lowering.
44594458
let mut insn_idx_to_block = HashMap::new();
4459+
// Make blocks for optionals first, and put them right next to their JIT entrypoint
4460+
for insn_idx in jit_entry_insns.iter().copied() {
4461+
let jit_entry_block = fun.new_block(insn_idx);
4462+
fun.jit_entry_blocks.push(jit_entry_block);
4463+
insn_idx_to_block.entry(insn_idx).or_insert_with(|| fun.new_block(insn_idx));
4464+
}
4465+
// Make blocks for the rest of the jump targets
44604466
for insn_idx in jump_targets {
4461-
// Prepend a JIT entry block if it's a jit_entry_insn.
4462-
// compile_entry_block() assumes that a JIT entry block jumps to the next block.
4463-
if jit_entry_insns.contains(&insn_idx) {
4464-
let jit_entry_block = fun.new_block(insn_idx);
4465-
fun.jit_entry_blocks.push(jit_entry_block);
4466-
}
4467-
insn_idx_to_block.insert(insn_idx, fun.new_block(insn_idx));
4467+
insn_idx_to_block.entry(insn_idx).or_insert_with(|| fun.new_block(insn_idx));
4468+
}
4469+
// Done, drop `mut`.
4470+
let insn_idx_to_block = insn_idx_to_block;
4471+
4472+
// Compile an entry_block for the interpreter
4473+
compile_entry_block(&mut fun, jit_entry_insns.as_slice(), &insn_idx_to_block);
4474+
4475+
// Compile all JIT-to-JIT entry blocks
4476+
for (jit_entry_idx, insn_idx) in jit_entry_insns.iter().enumerate() {
4477+
let target_block = insn_idx_to_block.get(insn_idx)
4478+
.copied()
4479+
.expect("we make a block for each jump target and \
4480+
each entry in the ISEQ opt_table is a jump target");
4481+
compile_jit_entry_block(&mut fun, jit_entry_idx, target_block);
44684482
}
44694483

44704484
// Check if the EP is escaped for the ISEQ from the beginning. We give up
44714485
// optimizing locals in that case because they're shared with other frames.
44724486
let ep_escaped = iseq_escapes_ep(iseq);
44734487

4474-
// Compile an entry_block for the interpreter
4475-
compile_entry_block(&mut fun, &jit_entry_insns);
4476-
4477-
// Iteratively fill out basic blocks using a queue
4488+
// Iteratively fill out basic blocks using a queue.
44784489
// TODO(max): Basic block arguments at edges
44794490
let mut queue = VecDeque::new();
4480-
queue.push_back((FrameState::new(iseq), insn_idx_to_block[&0], /*insn_idx=*/0, /*local_inval=*/false));
4491+
for &insn_idx in jit_entry_insns.iter() {
4492+
queue.push_back((FrameState::new(iseq), insn_idx_to_block[&insn_idx], /*insn_idx=*/insn_idx, /*local_inval=*/false));
4493+
}
44814494

44824495
// Keep compiling blocks until the queue becomes empty
44834496
let mut visited = HashSet::new();
@@ -4487,16 +4500,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
44874500
if visited.contains(&block) { continue; }
44884501
visited.insert(block);
44894502

4490-
// Compile a JIT entry to the block if it's a jit_entry_insn
4491-
if let Some(block_idx) = jit_entry_insns.iter().position(|&idx| idx == insn_idx) {
4492-
compile_jit_entry_block(&mut fun, block_idx, block);
4493-
}
4494-
44954503
// Load basic block params first
44964504
let self_param = fun.push_insn(block, Insn::Param);
44974505
let mut state = {
44984506
let mut result = FrameState::new(iseq);
4499-
let local_size = if insn_idx == 0 { num_locals(iseq) } else { incoming_state.locals.len() };
4507+
let local_size = if jit_entry_insns.contains(&insn_idx) { num_locals(iseq) } else { incoming_state.locals.len() };
45004508
for _ in 0..local_size {
45014509
result.locals.push(fun.push_insn(block, Insn::Param));
45024510
}
@@ -5348,14 +5356,6 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
53485356
}
53495357
}
53505358

5351-
// Bail out if there's a JIT entry block whose target block was not compiled
5352-
// due to an unsupported instruction in the middle of the ISEQ.
5353-
for jit_entry_block in fun.jit_entry_blocks.iter() {
5354-
if fun.blocks[jit_entry_block.0].insns.is_empty() {
5355-
return Err(ParseError::FailedOptionalArguments);
5356-
}
5357-
}
5358-
53595359
fun.set_param_types();
53605360
fun.infer_types();
53615361

@@ -5375,44 +5375,46 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
53755375
}
53765376

53775377
/// Compile an entry_block for the interpreter
5378-
fn compile_entry_block(fun: &mut Function, jit_entry_insns: &[u32]) {
5378+
fn compile_entry_block(fun: &mut Function, jit_entry_insns: &[u32], insn_idx_to_block: &HashMap<u32, BlockId>) {
53795379
let entry_block = fun.entry_block;
5380-
fun.push_insn(entry_block, Insn::EntryPoint { jit_entry_idx: None });
5381-
5382-
// Prepare entry_state with basic block params
5383-
let (self_param, entry_state) = compile_entry_state(fun, entry_block);
5384-
5385-
// Jump to target blocks
5380+
let (self_param, entry_state) = compile_entry_state(fun);
53865381
let mut pc: Option<InsnId> = None;
5387-
let &last_entry_insn = jit_entry_insns.last().unwrap();
5388-
for (jit_entry_block_idx, &jit_entry_insn) in jit_entry_insns.iter().enumerate() {
5389-
let jit_entry_block = fun.jit_entry_blocks[jit_entry_block_idx];
5390-
let target_block = BlockId(jit_entry_block.0 + 1); // jit_entry_block precedes the jump target block
5391-
5392-
if jit_entry_insn == last_entry_insn {
5393-
// If it's the last possible entry, jump to the target_block without checking PC
5394-
fun.push_insn(entry_block, Insn::Jump(BranchEdge { target: target_block, args: entry_state.as_args(self_param) }));
5395-
} else {
5396-
// Otherwise, jump to the target_block only if PC matches.
5397-
let pc = pc.unwrap_or_else(|| {
5398-
let insn_id = fun.push_insn(entry_block, Insn::LoadPC);
5399-
pc = Some(insn_id);
5400-
insn_id
5401-
});
5402-
let expected_pc = fun.push_insn(entry_block, Insn::Const {
5403-
val: Const::CPtr(unsafe { rb_iseq_pc_at_idx(fun.iseq, jit_entry_insn) } as *const u8),
5404-
});
5405-
let test_id = fun.push_insn(entry_block, Insn::IsBitEqual { left: pc, right: expected_pc });
5406-
fun.push_insn(entry_block, Insn::IfTrue {
5407-
val: test_id,
5408-
target: BranchEdge { target: target_block, args: entry_state.as_args(self_param) },
5409-
});
5382+
let &all_opts_passed_insn_idx = jit_entry_insns.last().unwrap();
5383+
5384+
// Check-and-jump for each missing optional PC
5385+
for &jit_entry_insn in jit_entry_insns.iter() {
5386+
if jit_entry_insn == all_opts_passed_insn_idx {
5387+
continue;
54105388
}
5389+
let target_block = insn_idx_to_block.get(&jit_entry_insn)
5390+
.copied()
5391+
.expect("we make a block for each jump target and \
5392+
each entry in the ISEQ opt_table is a jump target");
5393+
// Load PC once at the start of the block, shared among all cases
5394+
let pc = *pc.get_or_insert_with(|| fun.push_insn(entry_block, Insn::LoadPC));
5395+
let expected_pc = fun.push_insn(entry_block, Insn::Const {
5396+
val: Const::CPtr(unsafe { rb_iseq_pc_at_idx(fun.iseq, jit_entry_insn) } as *const u8),
5397+
});
5398+
let test_id = fun.push_insn(entry_block, Insn::IsBitEqual { left: pc, right: expected_pc });
5399+
fun.push_insn(entry_block, Insn::IfTrue {
5400+
val: test_id,
5401+
target: BranchEdge { target: target_block, args: entry_state.as_args(self_param) },
5402+
});
54115403
}
5404+
5405+
// Terminate the block with a jump to the block with all optionals passed
5406+
let target_block = insn_idx_to_block.get(&all_opts_passed_insn_idx)
5407+
.copied()
5408+
.expect("we make a block for each jump target and \
5409+
each entry in the ISEQ opt_table is a jump target");
5410+
fun.push_insn(entry_block, Insn::Jump(BranchEdge { target: target_block, args: entry_state.as_args(self_param) }));
54125411
}
54135412

54145413
/// Compile initial locals for an entry_block for the interpreter
5415-
fn compile_entry_state(fun: &mut Function, entry_block: BlockId) -> (InsnId, FrameState) {
5414+
fn compile_entry_state(fun: &mut Function) -> (InsnId, FrameState) {
5415+
let entry_block = fun.entry_block;
5416+
fun.push_insn(entry_block, Insn::EntryPoint { jit_entry_idx: None });
5417+
54165418
let iseq = fun.iseq;
54175419
let param_size = unsafe { get_iseq_body_param_size(iseq) }.to_usize();
54185420
let rest_param_idx = iseq_rest_param_idx(iseq);
@@ -5438,21 +5440,27 @@ fn compile_jit_entry_block(fun: &mut Function, jit_entry_idx: usize, target_bloc
54385440
fun.push_insn(jit_entry_block, Insn::EntryPoint { jit_entry_idx: Some(jit_entry_idx) });
54395441

54405442
// Prepare entry_state with basic block params
5441-
let (self_param, entry_state) = compile_jit_entry_state(fun, jit_entry_block);
5443+
let (self_param, entry_state) = compile_jit_entry_state(fun, jit_entry_block, jit_entry_idx);
54425444

54435445
// Jump to target_block
54445446
fun.push_insn(jit_entry_block, Insn::Jump(BranchEdge { target: target_block, args: entry_state.as_args(self_param) }));
54455447
}
54465448

54475449
/// Compile params and initial locals for a jit_entry_block
5448-
fn compile_jit_entry_state(fun: &mut Function, jit_entry_block: BlockId) -> (InsnId, FrameState) {
5450+
fn compile_jit_entry_state(fun: &mut Function, jit_entry_block: BlockId, jit_entry_idx: usize) -> (InsnId, FrameState) {
54495451
let iseq = fun.iseq;
54505452
let param_size = unsafe { get_iseq_body_param_size(iseq) }.to_usize();
5453+
let opt_num: usize = unsafe { get_iseq_body_param_opt_num(iseq) }.try_into().expect("iseq param opt_num >= 0");
5454+
let lead_num: usize = unsafe { get_iseq_body_param_lead_num(iseq) }.try_into().expect("iseq param lead_num >= 0");
5455+
let passed_opt_num = jit_entry_idx;
54515456

54525457
let self_param = fun.push_insn(jit_entry_block, Insn::Param);
54535458
let mut entry_state = FrameState::new(iseq);
54545459
for local_idx in 0..num_locals(iseq) {
5455-
if local_idx < param_size {
5460+
if (lead_num + passed_opt_num..lead_num + opt_num).contains(&local_idx) {
5461+
// Omitted optionals are locals, so they start as nil before their initialization code runs
5462+
entry_state.locals.push(fun.push_insn(jit_entry_block, Insn::Const { val: Const::Value(Qnil) }));
5463+
} else if local_idx < param_size {
54565464
entry_state.locals.push(fun.push_insn(jit_entry_block, Insn::Param));
54575465
} else {
54585466
entry_state.locals.push(fun.push_insn(jit_entry_block, Insn::Const { val: Const::Value(Qnil) }));

0 commit comments

Comments
 (0)