Commit 1b95e9c

Implement JIT-to-JIT calls (Shopify/zjit#109)

* Implement JIT-to-JIT calls
* Use a closer dummy address for Arm64
* Revert an obsoleted change
* Revert a few more obsoleted changes
* Fix outdated comments
* Explain PosMarkers for CCall
* s/JIT code/machine code/
* Get rid of ParallelMov

1 parent 4f43a09 commit 1b95e9c

File tree: 7 files changed, +370 -116 lines changed

test/ruby/test_zjit.rb

Lines changed: 45 additions & 3 deletions
@@ -389,6 +389,21 @@ def test = callee + callee
     }
   end
 
+  def test_method_call
+    assert_compiles '12', %q{
+      def callee(a, b)
+        a - b
+      end
+
+      def test
+        callee(4, 2) + 10
+      end
+
+      test # profile test
+      test
+    }, call_threshold: 2
+  end
+
   def test_recursive_fact
     assert_compiles '[1, 6, 720]', %q{
       def fact(n)
@@ -401,6 +416,19 @@ def fact(n)
     }
   end
 
+  def test_profiled_fact
+    assert_compiles '[1, 6, 720]', %q{
+      def fact(n)
+        if n == 0
+          return 1
+        end
+        return n * fact(n-1)
+      end
+      fact(1) # profile fact
+      [fact(0), fact(3), fact(6)]
+    }, call_threshold: 3, num_profiles: 2
+  end
+
   def test_recursive_fib
     assert_compiles '[0, 2, 3]', %q{
       def fib(n)
@@ -413,11 +441,24 @@ def fib(n)
     }
   end
 
+  def test_profiled_fib
+    assert_compiles '[0, 2, 3]', %q{
+      def fib(n)
+        if n < 2
+          return n
+        end
+        return fib(n-1) + fib(n-2)
+      end
+      fib(3) # profile fib
+      [fib(0), fib(3), fib(4)]
+    }, call_threshold: 5, num_profiles: 3
+  end
+
   private
 
   # Assert that every method call in `test_script` can be compiled by ZJIT
   # at a given call_threshold
-  def assert_compiles(expected, test_script, call_threshold: 1)
+  def assert_compiles(expected, test_script, **opts)
     pipe_fd = 3
 
     script = <<~RUBY
@@ -429,7 +470,7 @@ def assert_compiles(expected, test_script, call_threshold: 1)
      IO.open(#{pipe_fd}).write(result.inspect)
    RUBY
 
-    status, out, err, actual = eval_with_jit(script, call_threshold:, pipe_fd:)
+    status, out, err, actual = eval_with_jit(script, pipe_fd:, **opts)
 
    message = "exited with status #{status.to_i}"
    message << "\nstdout:\n```\n#{out}```\n" unless out.empty?
@@ -440,10 +481,11 @@ def assert_compiles(expected, test_script, call_threshold: 1)
   end
 
   # Run a Ruby process with ZJIT options and a pipe for writing test results
-  def eval_with_jit(script, call_threshold: 1, timeout: 1000, pipe_fd:, debug: true)
+  def eval_with_jit(script, call_threshold: 1, num_profiles: 1, timeout: 1000, pipe_fd:, debug: true)
     args = [
       "--disable-gems",
       "--zjit-call-threshold=#{call_threshold}",
+      "--zjit-num-profiles=#{num_profiles}",
     ]
     args << "--zjit-debug" if debug
     args << "-e" << script_shell_encode(script)

zjit/src/asm/mod.rs

Lines changed: 22 additions & 0 deletions
@@ -111,6 +111,28 @@ impl CodeBlock {
         self.get_ptr(self.write_pos)
     }
 
+    /// Set the current write position from a pointer
+    fn set_write_ptr(&mut self, code_ptr: CodePtr) {
+        let pos = code_ptr.as_offset() - self.mem_block.borrow().start_ptr().as_offset();
+        self.write_pos = pos.try_into().unwrap();
+    }
+
+    /// Invoke a callback with write_ptr temporarily adjusted to a given address
+    pub fn with_write_ptr(&mut self, code_ptr: CodePtr, callback: impl Fn(&mut CodeBlock)) {
+        // Temporarily update the write_pos. Ignore the dropped_bytes flag at the old address.
+        let old_write_pos = self.write_pos;
+        let old_dropped_bytes = self.dropped_bytes;
+        self.set_write_ptr(code_ptr);
+        self.dropped_bytes = false;
+
+        // Invoke the callback
+        callback(self);
+
+        // Restore the original write_pos and dropped_bytes flag.
+        self.dropped_bytes = old_dropped_bytes;
+        self.write_pos = old_write_pos;
+    }
+
     /// Get a (possibly dangling) direct pointer into the executable memory block
     pub fn get_ptr(&self, offset: usize) -> CodePtr {
         self.mem_block.borrow().start_ptr().add_bytes(offset)
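
As a rough illustration of how `with_write_ptr` can be used (a sketch with an assumed helper, not code from this commit): record a call site's address while compiling, then later jump back to that address and re-emit the call once the callee has been compiled to machine code.

```rust
// Hypothetical helper: stands in for whatever re-assembles a call/branch to
// `target` at the current write position. Not part of this diff.
fn emit_call_to(cb: &mut CodeBlock, target: CodePtr) {
    let _ = (cb, target); // real code would write the encoded instruction here
}

// Sketch: rewrite a previously recorded call site so it targets newly
// compiled JIT code. `with_write_ptr` restores the original write position
// (and dropped_bytes flag) after the callback returns.
fn patch_recorded_call(cb: &mut CodeBlock, call_start: CodePtr, new_target: CodePtr) {
    cb.with_write_ptr(call_start, |cb| {
        emit_call_to(cb, new_target);
    });
}
```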

zjit/src/backend/arm64/mod.rs

Lines changed: 14 additions & 12 deletions
@@ -491,19 +491,21 @@ impl Assembler
             // register.
             // Note: the iteration order is reversed to avoid corrupting x0,
             // which is both the return value and first argument register
-            let mut args: Vec<(Reg, Opnd)> = vec![];
-            for (idx, opnd) in opnds.into_iter().enumerate().rev() {
-                // If the value that we're sending is 0, then we can use
-                // the zero register, so in this case we'll just send
-                // a UImm of 0 along as the argument to the move.
-                let value = match opnd {
-                    Opnd::UImm(0) | Opnd::Imm(0) => Opnd::UImm(0),
-                    Opnd::Mem(_) => split_memory_address(asm, *opnd),
-                    _ => *opnd
-                };
-                args.push((C_ARG_OPNDS[idx].unwrap_reg(), value));
+            if !opnds.is_empty() {
+                let mut args: Vec<(Reg, Opnd)> = vec![];
+                for (idx, opnd) in opnds.into_iter().enumerate().rev() {
+                    // If the value that we're sending is 0, then we can use
+                    // the zero register, so in this case we'll just send
+                    // a UImm of 0 along as the argument to the move.
+                    let value = match opnd {
+                        Opnd::UImm(0) | Opnd::Imm(0) => Opnd::UImm(0),
+                        Opnd::Mem(_) => split_memory_address(asm, *opnd),
+                        _ => *opnd
+                    };
+                    args.push((C_ARG_OPNDS[idx].unwrap_reg(), value));
+                }
+                asm.parallel_mov(args);
             }
-            asm.parallel_mov(args);
 
             // Now we push the CCall without any arguments so that it
             // just performs the call.
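
With the new `!opnds.is_empty()` guard, a zero-argument C call no longer emits an empty `ParallelMov`; only the `CCall` itself is pushed. A minimal caller-side sketch of that case (the C helper below is invented for the example):

```rust
// Invented for illustration: a C function that takes no arguments.
extern "C" fn no_arg_helper() -> u64 {
    0
}

// With empty `opnds`, the split pass above skips the ParallelMov entirely.
fn gen_no_arg_ccall(asm: &mut Assembler) -> Opnd {
    asm.ccall(no_arg_helper as *const u8, vec![])
}
```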

zjit/src/backend/lir.rs

Lines changed: 80 additions & 60 deletions
@@ -345,7 +345,19 @@ pub enum Insn {
     CPushAll,
 
     // C function call with N arguments (variadic)
-    CCall { opnds: Vec<Opnd>, fptr: *const u8, out: Opnd },
+    CCall {
+        opnds: Vec<Opnd>,
+        fptr: *const u8,
+        /// Optional PosMarker to remember the start address of the C call.
+        /// It's embedded here to insert the PosMarker after push instructions
+        /// that are split from this CCall on alloc_regs().
+        start_marker: Option<PosMarkerFn>,
+        /// Optional PosMarker to remember the end address of the C call.
+        /// It's embedded here to insert the PosMarker before pop instructions
+        /// that are split from this CCall on alloc_regs().
+        end_marker: Option<PosMarkerFn>,
+        out: Opnd,
+    },
 
     // C function return
     CRet(Opnd),
@@ -1455,6 +1467,23 @@ impl Assembler
         let mut asm = Assembler::new_with_label_names(take(&mut self.label_names), live_ranges.len());
 
         while let Some((index, mut insn)) = iterator.next() {
+            let before_ccall = match (&insn, iterator.peek().map(|(_, insn)| insn)) {
+                (Insn::ParallelMov { .. }, Some(Insn::CCall { .. })) |
+                (Insn::CCall { .. }, _) if !pool.is_empty() => {
+                    // If C_RET_REG is in use, move it to another register.
+                    // This must happen before last-use registers are deallocated.
+                    if let Some(vreg_idx) = pool.vreg_for(&C_RET_REG) {
+                        let new_reg = pool.alloc_reg(vreg_idx).unwrap(); // TODO: support spill
+                        asm.mov(Opnd::Reg(new_reg), C_RET_OPND);
+                        pool.dealloc_reg(&C_RET_REG);
+                        reg_mapping[vreg_idx] = Some(new_reg);
+                    }
+
+                    true
+                },
+                _ => false,
+            };
+
             // Check if this is the last instruction that uses an operand that
             // spans more than one instruction. In that case, return the
             // allocated register to the pool.
@@ -1477,32 +1506,20 @@ impl Assembler
                 }
             }
 
-            // If we're about to make a C call, save caller-saved registers
-            match (&insn, iterator.peek().map(|(_, insn)| insn)) {
-                (Insn::ParallelMov { .. }, Some(Insn::CCall { .. })) |
-                (Insn::CCall { .. }, _) if !pool.is_empty() => {
-                    // If C_RET_REG is in use, move it to another register
-                    if let Some(vreg_idx) = pool.vreg_for(&C_RET_REG) {
-                        let new_reg = pool.alloc_reg(vreg_idx).unwrap(); // TODO: support spill
-                        asm.mov(Opnd::Reg(new_reg), C_RET_OPND);
-                        pool.dealloc_reg(&C_RET_REG);
-                        reg_mapping[vreg_idx] = Some(new_reg);
-                    }
-
-                    // Find all live registers
-                    saved_regs = pool.live_regs();
+            // Save caller-saved registers on a C call.
+            if before_ccall {
+                // Find all live registers
+                saved_regs = pool.live_regs();
 
-                    // Save live registers
-                    for &(reg, _) in saved_regs.iter() {
-                        asm.cpush(Opnd::Reg(reg));
-                        pool.dealloc_reg(&reg);
-                    }
-                    // On x86_64, maintain 16-byte stack alignment
-                    if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 {
-                        asm.cpush(Opnd::Reg(saved_regs.last().unwrap().0));
-                    }
+                // Save live registers
+                for &(reg, _) in saved_regs.iter() {
+                    asm.cpush(Opnd::Reg(reg));
+                    pool.dealloc_reg(&reg);
+                }
+                // On x86_64, maintain 16-byte stack alignment
+                if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 {
+                    asm.cpush(Opnd::Reg(saved_regs.last().unwrap().0));
                 }
-                _ => {},
             }
 
             // If the output VReg of this instruction is used by another instruction,
@@ -1590,13 +1607,24 @@ impl Assembler
 
             // Push instruction(s)
             let is_ccall = matches!(insn, Insn::CCall { .. });
-            if let Insn::ParallelMov { moves } = insn {
-                // Now that register allocation is done, it's ready to resolve parallel moves.
-                for (reg, opnd) in Self::resolve_parallel_moves(&moves) {
-                    asm.load_into(Opnd::Reg(reg), opnd);
+            match insn {
+                Insn::ParallelMov { moves } => {
+                    // Now that register allocation is done, it's ready to resolve parallel moves.
+                    for (reg, opnd) in Self::resolve_parallel_moves(&moves) {
+                        asm.load_into(Opnd::Reg(reg), opnd);
+                    }
                 }
-            } else {
-                asm.push_insn(insn);
+                Insn::CCall { opnds, fptr, start_marker, end_marker, out } => {
+                    // Split start_marker and end_marker here to avoid inserting push/pop between them.
+                    if let Some(start_marker) = start_marker {
+                        asm.push_insn(Insn::PosMarker(start_marker));
+                    }
+                    asm.push_insn(Insn::CCall { opnds, fptr, start_marker: None, end_marker: None, out });
+                    if let Some(end_marker) = end_marker {
+                        asm.push_insn(Insn::PosMarker(end_marker));
+                    }
+                }
+                _ => asm.push_insn(insn),
             }
 
             // After a C call, restore caller-saved registers
@@ -1720,38 +1748,30 @@ impl Assembler {
         self.push_insn(Insn::Breakpoint);
     }
 
+    /// Call a C function without PosMarkers
     pub fn ccall(&mut self, fptr: *const u8, opnds: Vec<Opnd>) -> Opnd {
-        /*
-        // Let vm_check_canary() assert this ccall's leafness if leaf_ccall is set
-        let canary_opnd = self.set_stack_canary(&opnds);
-
-        let old_temps = self.ctx.get_reg_mapping(); // with registers
-        // Spill stack temp registers since they are caller-saved registers.
-        // Note that this doesn't spill stack temps that are already popped
-        // but may still be used in the C arguments.
-        self.spill_regs();
-        let new_temps = self.ctx.get_reg_mapping(); // all spilled
-
-        // Temporarily manipulate RegMappings so that we can use registers
-        // to pass stack operands that are already spilled above.
-        self.ctx.set_reg_mapping(old_temps);
-        */
-
-        // Call a C function
         let out = self.new_vreg(Opnd::match_num_bits(&opnds));
-        self.push_insn(Insn::CCall { fptr, opnds, out });
-
-        /*
-        // Registers in old_temps may be clobbered by the above C call,
-        // so rollback the manipulated RegMappings to a spilled version.
-        self.ctx.set_reg_mapping(new_temps);
-
-        // Clear the canary after use
-        if let Some(canary_opnd) = canary_opnd {
-            self.mov(canary_opnd, 0.into());
-        }
-        */
+        self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out });
+        out
+    }
 
+    /// Call a C function with PosMarkers. This is used for recording the start and end
+    /// addresses of the C call and rewriting it with a different function address later.
+    pub fn ccall_with_pos_markers(
+        &mut self,
+        fptr: *const u8,
+        opnds: Vec<Opnd>,
+        start_marker: impl Fn(CodePtr, &CodeBlock) + 'static,
+        end_marker: impl Fn(CodePtr, &CodeBlock) + 'static,
+    ) -> Opnd {
+        let out = self.new_vreg(Opnd::match_num_bits(&opnds));
+        self.push_insn(Insn::CCall {
+            fptr,
+            opnds,
+            start_marker: Some(Box::new(start_marker)),
+            end_marker: Some(Box::new(end_marker)),
+            out,
+        });
         out
     }
 
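To show how the new API fits together (a minimal sketch; the `Rc<Cell<...>>` bookkeeping is invented for this example and is not the codegen from this commit): the start marker runs once the call's address is known, so the recorded address can later be fed to `CodeBlock::with_write_ptr` to rewrite the call with a different target.

```rust
use std::cell::Cell;
use std::rc::Rc;

// Sketch: emit a C call and remember where it starts, so the call site can be
// patched later (e.g. redirected to freshly compiled machine code).
fn gen_recorded_ccall(
    asm: &mut Assembler,
    fptr: *const u8,
    opnds: Vec<Opnd>,
) -> (Opnd, Rc<Cell<Option<CodePtr>>>) {
    // Shared slot that the start PosMarker fills in at assembly time.
    let start_addr = Rc::new(Cell::new(None));
    let start_addr_for_marker = Rc::clone(&start_addr);

    let out = asm.ccall_with_pos_markers(
        fptr,
        opnds,
        // Runs when the CCall's start address is known.
        move |code_ptr, _cb| start_addr_for_marker.set(Some(code_ptr)),
        // The end marker could record the return address the same way.
        |_code_ptr, _cb| {},
    );
    (out, start_addr)
}
```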

zjit/src/backend/x86_64/mod.rs

Lines changed: 7 additions & 5 deletions
@@ -82,8 +82,8 @@ impl From<&Opnd> for X86Opnd {
 /// This has the same number of registers for x86_64 and arm64.
 /// SCRATCH_REG is excluded.
 pub const ALLOC_REGS: &'static [Reg] = &[
-    RSI_REG,
     RDI_REG,
+    RSI_REG,
     RDX_REG,
     RCX_REG,
     R8_REG,
@@ -338,11 +338,13 @@ impl Assembler
         assert!(opnds.len() <= C_ARG_OPNDS.len());
 
         // Load each operand into the corresponding argument register.
-        let mut args: Vec<(Reg, Opnd)> = vec![];
-        for (idx, opnd) in opnds.into_iter().enumerate() {
-            args.push((C_ARG_OPNDS[idx].unwrap_reg(), *opnd));
+        if !opnds.is_empty() {
+            let mut args: Vec<(Reg, Opnd)> = vec![];
+            for (idx, opnd) in opnds.into_iter().enumerate() {
+                args.push((C_ARG_OPNDS[idx].unwrap_reg(), *opnd));
+            }
+            asm.parallel_mov(args);
         }
-        asm.parallel_mov(args);
 
         // Now we push the CCall without any arguments so that it
         // just performs the call.
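
One reference note on the `ALLOC_REGS` reorder above (context, not code from this commit): the System V AMD64 calling convention passes integer arguments in RDI, RSI, RDX, RCX, R8, R9, so the swapped ordering makes the start of the allocation list match that sequence.

```rust
// Reference only: System V AMD64 integer argument registers in
// calling-convention order. ALLOC_REGS now begins with this same sequence.
pub const SYSV_INT_ARG_REGS: [&str; 6] = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
```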
