Skip to content

Commit 4c231bf

Browse files
committed
runtime: add tail call support
1 parent 2e1e9c7 commit 4c231bf

File tree

4 files changed

+210
-35
lines changed

4 files changed

+210
-35
lines changed

runtime/src/classes/java/lang/invoke/MethodHandle.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,22 @@ mod _dynamic {
132132
}
133133
}
134134

135+
fn invoke_target(frame: &mut Frame, method: &'static Method) {
136+
let mut parameters_count = method.parameter_count() as usize;
137+
if !method.is_static() {
138+
parameters_count += 1;
139+
}
140+
141+
unsafe {
142+
frame.thread().tail_call(method);
143+
frame.stack_mut().seek_stack_pointer(1);
144+
}
145+
}
146+
135147
pub fn invoke_basic(frame: &mut Frame, entry: MethodEntry) {
136148
// Add 1 to the parameters size, since it doesn't account for `this`
137149
let parameters_count = (entry.parameters_stack_size as usize) + 1;
150+
debug_assert!(frame.stack().len() >= parameters_count);
138151

139152
let receiver = frame.stack().at(parameters_count).expect_reference();
140153
if receiver.is_null() {
@@ -145,11 +158,7 @@ mod _dynamic {
145158
return;
146159
};
147160

148-
let call_args = frame.stack_mut().popn(parameters_count);
149-
let call_args =
150-
unsafe { LocalStack::new_with_args(call_args, target_method.code.max_locals as usize) };
151-
let ret = java_call!(@WITH_ARGS_LIST frame.thread(), target_method, call_args);
152-
morph_return_value(frame, ret);
161+
invoke_target(frame, target_method);
153162
}
154163

155164
pub fn invoke_exact(frame: &mut Frame, entry: MethodEntry) {
@@ -177,7 +186,7 @@ mod _dynamic {
177186
let Some((_appendix, target_method)) = appendix_and_target_method(frame) else {
178187
return;
179188
};
180-
MethodInvoker::invoke_virtual(frame, target_method);
189+
invoke_target(frame, target_method);
181190
}
182191

183192
pub fn link_to_special(frame: &mut Frame) {

runtime/src/interpreter.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::thread::exceptions::{
1919
Exception, ExceptionKind, Throws, handle_exception, throw, throw_with_ret,
2020
};
2121
use crate::thread::frame::{Frame, PcUpdateStrategy};
22-
use crate::thread::{JavaThread, exceptions};
22+
use crate::thread::{ControlFlow, JavaThread, exceptions};
2323
use crate::{classes, java_call};
2424

2525
use std::cmp::Ordering;
@@ -240,7 +240,12 @@ macro_rules! comparisons {
240240
macro_rules! control_return {
241241
($frame:ident, $instruction:ident) => {{
242242
let thread = $frame.thread();
243-
thread.drop_to_previous_frame(None);
243+
if !$frame.in_tail_call() {
244+
thread.drop_to_previous_frame(None);
245+
return;
246+
}
247+
248+
thread.set_control_flow(ControlFlow::Break);
244249
return;
245250
}};
246251
($frame:ident, $instruction:ident, $return_ty:ident) => {{
@@ -257,7 +262,12 @@ macro_rules! control_return {
257262
}
258263

259264
let thread = $frame.thread();
260-
thread.drop_to_previous_frame(Some(value));
265+
if !$frame.in_tail_call() {
266+
thread.drop_to_previous_frame(Some(value));
267+
return;
268+
}
269+
270+
thread.set_control_flow(ControlFlow::Break);
261271
return;
262272
}};
263273
}
@@ -1091,7 +1101,7 @@ impl Interpreter {
10911101
MethodInvoker::invoke_virtual(frame, entry.method);
10921102
return;
10931103
}
1094-
1104+
10951105
if entry.method.class().is_subclass_of(crate::globals::classes::java_lang_invoke_MethodHandle()) {
10961106
let Some(MethodEntryPoint::MethodHandleInvoker(mh_invoker)) = entry.method.entry_point() else {
10971107
panic!("Expected MethodHandleInvoker entry point");

runtime/src/thread/frame/mod.rs

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ use crate::thread::exceptions::{ExceptionKind, Throws};
1010

1111
use std::cell::UnsafeCell;
1212
use std::fmt::{Debug, Formatter};
13-
use std::mem;
1413
use std::sync::atomic::{AtomicIsize, Ordering};
1514

1615
use common::int_types::{s1, s2, s4, u1, u2, u4};
16+
use instructions::StackLike;
1717

1818
// https://docs.oracle.com/javase/specs/jvms/se23/html/jvms-2.html#jvms-2.6
1919
#[rustfmt::skip]
@@ -26,12 +26,24 @@ pub struct Frame {
2626
stack: OperandStack,
2727
// and a reference to the run-time constant pool (§2.5.5)
2828
constant_pool: &'static ConstantPool,
29-
method: &'static Method,
29+
30+
// Fields outside the spec:
31+
32+
method: &'static Method,
3033
thread: UnsafeCell<*const JavaThread>,
3134

3235
// Used to remember the last pc when we return to a frame after a method invocation
3336
cached_pc: AtomicIsize,
34-
pub depth: isize,
37+
38+
// TODO: depth should never be > 5, could be packed with flags
39+
// Extra depth within the current instruction
40+
//
41+
// When parsing a bytecode instruction, `pc` stays at the beginning of the instruction. This keeps
42+
// track of any additional bytes we read *after* that bytecode (e.g. arguments for the instruction).
43+
//
44+
// The depth is used at the end of an instruction to calculate the offset to the next instruction.
45+
depth: u16,
46+
flags: u8,
3547
}
3648

3749
impl Debug for Frame {
@@ -45,6 +57,15 @@ impl Debug for Frame {
4557
}
4658
}
4759

60+
// Flags
61+
impl Frame {
62+
const IN_TAIL_CALL: u8 = 0b1;
63+
64+
pub fn in_tail_call(&self) -> bool {
65+
self.flags & Self::IN_TAIL_CALL != 0
66+
}
67+
}
68+
4869
impl Frame {
4970
/// Create a new `Frame` for a [`Method`] invocation
5071
///
@@ -70,8 +91,68 @@ impl Frame {
7091
thread: UnsafeCell::new(&raw const *thread),
7192
cached_pc: AtomicIsize::default(),
7293
depth: 0,
94+
flags: 0,
7395
})
96+
}
97+
98+
/// Reuse this frame for a tail method call
99+
///
100+
/// This will replace the original [`LocalStack`] and return it. It must be retained and used in
101+
/// a subsequent call to [`Self::reset_from_tail_call()`].
102+
///
103+
/// # Safety
104+
///
105+
/// The current [`OperandStack`] is retained (including its current position), so the stack
106+
/// ***must*** be setup correctly for the target `method`.
107+
pub(in crate::thread) unsafe fn swap_for_tail_call(
108+
&mut self,
109+
method: &'static Method,
110+
) -> LocalStack {
111+
assert!(method.parameter_count() as usize <= self.locals.total_slots());
112+
assert!(!self.has_stashed_depth());
113+
114+
let mut parameter_count = method.parameter_count() as usize;
115+
if !method.is_static() {
116+
// receiver
117+
parameter_count += 1;
74118
}
119+
120+
let locals = unsafe {
121+
LocalStack::new_with_args(
122+
self.stack_mut().popn(parameter_count),
123+
method.code.max_locals as usize,
124+
)
125+
};
126+
127+
let old_locals = core::mem::replace(&mut self.locals, locals);
128+
self.constant_pool = method
129+
.class()
130+
.constant_pool()
131+
.expect("Methods do not exist on array classes");
132+
133+
self.depth = self.depth << 8;
134+
self.method = method;
135+
self.flags |= Self::IN_TAIL_CALL;
136+
137+
old_locals
138+
}
139+
140+
/// Restore this frame to its state prior to a tail call
141+
///
142+
/// NOTE: The [`OperandStack`] will be left in whatever state the prior method returned with.
143+
pub(in crate::thread) fn reset_from_tail_call(
144+
&mut self,
145+
old_locals: LocalStack,
146+
method: &'static Method,
147+
) {
148+
self.locals = old_locals;
149+
self.constant_pool = method
150+
.class()
151+
.constant_pool()
152+
.expect("Methods do not exist on array classes");
153+
self.depth = self.depth >> 8;
154+
self.method = method;
155+
self.flags |= !Self::IN_TAIL_CALL;
75156
}
76157
}
77158

@@ -127,6 +208,19 @@ impl Frame {
127208
pub fn stashed_pc(&self) -> isize {
128209
self.cached_pc.load(Ordering::Relaxed)
129210
}
211+
212+
fn depth(&self) -> isize {
213+
(self.depth & 0b1111_1111) as isize
214+
}
215+
216+
fn inc_depth(&mut self) {
217+
assert!(self.depth() <= u8::MAX as isize);
218+
self.depth = (self.depth & 0xFF00) | ((self.depth & 0x00FF) + 1);
219+
}
220+
221+
fn has_stashed_depth(&self) -> bool {
222+
(self.depth >> 8) > 0
223+
}
130224
}
131225

132226
// Setters
@@ -159,8 +253,8 @@ impl Frame {
159253
pc = thread.pc.load(Ordering::Relaxed);
160254
}
161255

162-
let ret = self.method.code.code[(pc + self.depth) as usize];
163-
self.depth += 1;
256+
let ret = self.method.code.code[(pc + self.depth()) as usize];
257+
self.inc_depth();
164258

165259
ret
166260
}
@@ -206,17 +300,20 @@ impl Frame {
206300
///
207301
/// This is used in the `tableswitch` and `lookupswitch` instructions.
208302
pub fn skip_padding(&mut self) {
209-
let current_pc = self.thread().pc.load(Ordering::Relaxed) + self.depth;
303+
let current_pc = self.thread().pc.load(Ordering::Relaxed) + self.depth();
210304

211305
let mut pc = current_pc;
212306
while pc % 4 != 0 {
213307
pc += 1;
214-
self.depth += 1;
308+
self.inc_depth();
215309
}
216310
}
217311

218312
pub fn take_cached_depth(&mut self) -> isize {
219-
mem::replace(&mut self.depth, 0)
313+
let depth = self.depth();
314+
self.depth = 0;
315+
316+
depth
220317
}
221318

222319
/// Commit the [pc] to the current [`JavaThread`]
@@ -230,7 +327,7 @@ impl Frame {
230327
let _ = self.thread().pc.fetch_add(off, Ordering::Relaxed);
231328
},
232329
PcUpdateStrategy::FromInstruction => {
233-
let _ = self.thread().pc.fetch_add(self.depth, Ordering::Relaxed);
330+
let _ = self.thread().pc.fetch_add(self.depth(), Ordering::Relaxed);
234331
},
235332
}
236333

0 commit comments

Comments
 (0)