Skip to content

Commit a0d0b84

Browse files
authored
ZJIT: Support invalidating constant patch points (ruby#13998)
1 parent 23000e7 commit a0d0b84

File tree

6 files changed

+108
-3
lines changed

6 files changed

+108
-3
lines changed

test/ruby/test_zjit.rb

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,48 @@ def test = X
889889
end
890890
end
891891

892+
def test_constant_invalidation
893+
assert_compiles '123', <<~RUBY, call_threshold: 2, insns: [:opt_getconstant_path]
894+
class C; end
895+
def test = C
896+
test
897+
test
898+
899+
C = 123
900+
test
901+
RUBY
902+
end
903+
904+
def test_constant_path_invalidation
905+
assert_compiles '["Foo::C", "Foo::C", "Bar::C"]', <<~RUBY, call_threshold: 2, insns: [:opt_getconstant_path]
906+
module A
907+
module B; end
908+
end
909+
910+
module Foo
911+
C = "Foo::C"
912+
end
913+
914+
module Bar
915+
C = "Bar::C"
916+
end
917+
918+
A::B = Foo
919+
920+
def test = A::B::C
921+
922+
result = []
923+
924+
result << test
925+
result << test
926+
927+
A::B = Bar
928+
929+
result << test
930+
result
931+
RUBY
932+
end
933+
892934
def test_dupn
893935
assert_compiles '[[1], [1, 1], :rhs, [nil, :rhs]]', <<~RUBY, insns: [:dupn]
894936
def test(array) = (array[1, 2] ||= :rhs)

vm_method.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ rb_clear_constant_cache_for_id(ID id)
148148
}
149149

150150
rb_yjit_constant_state_changed(id);
151+
rb_zjit_constant_state_changed(id);
151152
}
152153

153154
static void

zjit.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ void rb_zjit_profile_enable(const rb_iseq_t *iseq);
1414
void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop);
1515
void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme);
1616
void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq);
17+
void rb_zjit_constant_state_changed(ID id);
1718
void rb_zjit_iseq_mark(void *payload);
1819
void rb_zjit_iseq_update_references(void *payload);
1920
#else
@@ -24,6 +25,7 @@ static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
2425
static inline void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop) {}
2526
static inline void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme) {}
2627
static inline void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq) {}
28+
static inline void rb_zjit_constant_state_changed(ID id) {}
2729
#endif // #if USE_YJIT
2830

2931
#endif // #ifndef ZJIT_H

zjit/src/codegen.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::ffi::{c_int};
44

55
use crate::asm::Label;
66
use crate::backend::current::{Reg, ALLOC_REGS};
7-
use crate::invariants::{track_bop_assumption, track_cme_assumption};
7+
use crate::invariants::{track_bop_assumption, track_cme_assumption, track_stable_constant_names_assumption};
88
use crate::gc::{get_or_create_iseq_payload, append_gc_offsets};
99
use crate::state::ZJITState;
1010
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
@@ -505,6 +505,10 @@ fn gen_patch_point(jit: &mut JITState, asm: &mut Assembler, invariant: &Invarian
505505
let side_exit_ptr = cb.resolve_label(label);
506506
track_cme_assumption(cme, code_ptr, side_exit_ptr);
507507
}
508+
Invariant::StableConstantNames { idlist } => {
509+
let side_exit_ptr = cb.resolve_label(label);
510+
track_stable_constant_names_assumption(idlist, code_ptr, side_exit_ptr);
511+
}
508512
_ => {
509513
debug!("ZJIT: gen_patch_point: unimplemented invariant {invariant:?}");
510514
return;

zjit/src/cruby.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ pub struct VALUE(pub usize);
259259
/// An interned string. See [ids] and methods this type.
260260
/// `0` is a sentinal value for IDs.
261261
#[repr(transparent)]
262-
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
262+
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
263263
pub struct ID(pub ::std::os::raw::c_ulong);
264264

265265
/// Pointer to an ISEQ

zjit/src/invariants.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{collections::{HashMap, HashSet}};
22

3-
use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr};
3+
use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag, ID}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr};
44

55
#[derive(Debug, Eq, Hash, PartialEq)]
66
struct Jump {
@@ -23,6 +23,9 @@ pub struct Invariants {
2323

2424
/// Map from CME to patch points that assume the method hasn't been redefined
2525
cme_patch_points: HashMap<*const rb_callable_method_entry_t, HashSet<Jump>>,
26+
27+
/// Map from constant ID to patch points that assume the constant hasn't been redefined
28+
constant_state_patch_points: HashMap<ID, HashSet<Jump>>,
2629
}
2730

2831
/// Called when a basic operator is redefined. Note that all the blocks assuming
@@ -116,6 +119,30 @@ pub fn track_cme_assumption(
116119
});
117120
}
118121

122+
/// Track a patch point for each constant name in a constant path assumption.
123+
pub fn track_stable_constant_names_assumption(
124+
idlist: *const ID,
125+
patch_point_ptr: CodePtr,
126+
side_exit_ptr: CodePtr
127+
) {
128+
let invariants = ZJITState::get_invariants();
129+
130+
let mut idx = 0;
131+
loop {
132+
let id = unsafe { *idlist.wrapping_add(idx) };
133+
if id.0 == 0 {
134+
break;
135+
}
136+
137+
invariants.constant_state_patch_points.entry(id).or_default().insert(Jump {
138+
from: patch_point_ptr,
139+
to: side_exit_ptr,
140+
});
141+
142+
idx += 1;
143+
}
144+
}
145+
119146
/// Called when a method is redefined. Invalidates all JIT code that depends on the CME.
120147
#[unsafe(no_mangle)]
121148
pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t) {
@@ -144,3 +171,32 @@ pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t)
144171
}
145172
});
146173
}
174+
175+
/// Called when a constant is redefined. Invalidates all JIT code that depends on the constant.
176+
#[unsafe(no_mangle)]
177+
pub extern "C" fn rb_zjit_constant_state_changed(id: ID) {
178+
// If ZJIT isn't enabled, do nothing
179+
if !zjit_enabled_p() {
180+
return;
181+
}
182+
183+
with_vm_lock(src_loc!(), || {
184+
let invariants = ZJITState::get_invariants();
185+
if let Some(jumps) = invariants.constant_state_patch_points.get(&id) {
186+
let cb = ZJITState::get_code_block();
187+
debug!("Constant state changed: {:?}", id);
188+
189+
// Invalidate all patch points for this constant ID
190+
for jump in jumps {
191+
cb.with_write_ptr(jump.from, |cb| {
192+
let mut asm = Assembler::new();
193+
asm_comment!(asm, "Constant state changed: {:?}", id);
194+
asm.jmp(jump.to.into());
195+
asm.compile(cb).expect("can write existing code");
196+
});
197+
}
198+
199+
cb.mark_all_executable();
200+
}
201+
});
202+
}

0 commit comments

Comments
 (0)