Skip to content

Commit ff89e47

Browse files
committed
ZJIT: Specialize Module#=== and Kernel#is_a? into IsA
1 parent e0bb3fb commit ff89e47

File tree

4 files changed

+225
-1
lines changed

4 files changed

+225
-1
lines changed

zjit/src/codegen.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
472472
Insn::ArrayInclude { elements, target, state } => gen_array_include(jit, asm, opnds!(elements), opnd!(target), &function.frame_state(*state)),
473473
&Insn::DupArrayInclude { ary, target, state } => gen_dup_array_include(jit, asm, ary, opnd!(target), &function.frame_state(state)),
474474
Insn::ArrayHash { elements, state } => gen_opt_newarray_hash(jit, asm, opnds!(elements), &function.frame_state(*state)),
475+
&Insn::IsA { val, class } => gen_is_a(asm, opnd!(val), opnd!(class)),
475476
&Insn::ArrayMax { state, .. }
476477
| &Insn::FixnumDiv { state, .. }
477478
| &Insn::Throw { state, .. }
@@ -1520,6 +1521,10 @@ fn gen_dup_array_include(
15201521
)
15211522
}
15221523

1524+
fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd {
1525+
asm_ccall!(asm, rb_obj_is_kind_of, obj, class)
1526+
}
1527+
15231528
/// Compile a new hash instruction
15241529
fn gen_new_hash(
15251530
jit: &mut JITState,

zjit/src/cruby_methods.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ pub fn init() -> Annotations {
197197
annotate!(rb_mKernel, "itself", inline_kernel_itself);
198198
annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p);
199199
annotate!(rb_mKernel, "===", inline_eqq);
200+
annotate!(rb_mKernel, "is_a?", inline_kernel_is_a_p);
200201
annotate!(rb_cString, "bytesize", inline_string_bytesize);
201202
annotate!(rb_cString, "size", types::Fixnum, no_gc, leaf, elidable);
202203
annotate!(rb_cString, "length", types::Fixnum, no_gc, leaf, elidable);
@@ -206,7 +207,7 @@ pub fn init() -> Annotations {
206207
annotate!(rb_cString, "<<", inline_string_append);
207208
annotate!(rb_cString, "==", inline_string_eq);
208209
annotate!(rb_cModule, "name", types::StringExact.union(types::NilClass), no_gc, leaf, elidable);
209-
annotate!(rb_cModule, "===", types::BoolExact, no_gc, leaf);
210+
annotate!(rb_cModule, "===", inline_module_eqq, types::BoolExact, no_gc, leaf);
210211
annotate!(rb_cArray, "length", types::Fixnum, no_gc, leaf, elidable);
211212
annotate!(rb_cArray, "size", types::Fixnum, no_gc, leaf, elidable);
212213
annotate!(rb_cArray, "empty?", types::BoolExact, no_gc, leaf, elidable);
@@ -447,6 +448,15 @@ fn inline_string_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::Ins
447448
None
448449
}
449450

451+
fn inline_module_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
452+
let &[other] = args else { return None; };
453+
if fun.is_a(recv, types::Class) {
454+
let result = fun.push_insn(block, hir::Insn::IsA { val: other, class: recv });
455+
return Some(result);
456+
}
457+
None
458+
}
459+
450460
fn inline_integer_succ(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
451461
if !args.is_empty() { return None; }
452462
if fun.likely_a(recv, types::Fixnum, state) {
@@ -613,6 +623,15 @@ fn inline_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, a
613623
Some(result)
614624
}
615625

626+
fn inline_kernel_is_a_p(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
627+
let &[other] = args else { return None; };
628+
if fun.is_a(other, types::Class) {
629+
let result = fun.push_insn(block, hir::Insn::IsA { val: recv, class: other });
630+
return Some(result);
631+
}
632+
None
633+
}
634+
616635
fn inline_kernel_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
617636
if !args.is_empty() { return None; }
618637
Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) }))

zjit/src/hir.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,9 @@ pub enum Insn {
721721
/// Test the bit at index of val, a Fixnum.
722722
/// Return Qtrue if the bit is set, else Qfalse.
723723
FixnumBitCheck { val: InsnId, index: u8 },
724+
/// Return Qtrue if `val` is an instance of `class`, else Qfalse.
725+
/// Equivalent to `class_search_ancestor(CLASS_OF(val), class)`.
726+
IsA { val: InsnId, class: InsnId },
724727

725728
/// Get a global variable named `id`
726729
GetGlobal { id: ID, state: InsnId },
@@ -1000,6 +1003,7 @@ impl Insn {
10001003
Insn::BoxFixnum { .. } => false,
10011004
Insn::BoxBool { .. } => false,
10021005
Insn::IsBitEqual { .. } => false,
1006+
Insn::IsA { .. } => false,
10031007
_ => true,
10041008
}
10051009
}
@@ -1324,6 +1328,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
13241328
}
13251329
Insn::IncrCounter(counter) => write!(f, "IncrCounter {counter:?}"),
13261330
Insn::CheckInterrupts { .. } => write!(f, "CheckInterrupts"),
1331+
Insn::IsA { val, class } => write!(f, "IsA {val}, {class}"),
13271332
}
13281333
}
13291334
}
@@ -1946,6 +1951,7 @@ impl Function {
19461951
&ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state },
19471952
&ArrayPush { array, val, state } => ArrayPush { array: find!(array), val: find!(val), state },
19481953
&CheckInterrupts { state } => CheckInterrupts { state },
1954+
&IsA { val, class } => IsA { val: find!(val), class: find!(class) },
19491955
}
19501956
}
19511957

@@ -2095,6 +2101,7 @@ impl Function {
20952101
// The type of Snapshot doesn't really matter; it's never materialized. It's used only
20962102
// as a reference for FrameState, which we use to generate side-exit code.
20972103
Insn::Snapshot { .. } => types::Any,
2104+
Insn::IsA { .. } => types::BoolExact,
20982105
}
20992106
}
21002107

@@ -3622,6 +3629,10 @@ impl Function {
36223629
&Insn::ObjectAllocClass { state, .. } |
36233630
&Insn::SideExit { state, .. } => worklist.push_back(state),
36243631
&Insn::UnboxFixnum { val } => worklist.push_back(val),
3632+
&Insn::IsA { val, class } => {
3633+
worklist.push_back(val);
3634+
worklist.push_back(class);
3635+
}
36253636
}
36263637
}
36273638

@@ -4314,6 +4325,10 @@ impl Function {
43144325
self.assert_subtype(insn_id, index, types::Fixnum)?;
43154326
self.assert_subtype(insn_id, value, types::Fixnum)
43164327
}
4328+
Insn::IsA { val, class } => {
4329+
self.assert_subtype(insn_id, val, types::BasicObject)?;
4330+
self.assert_subtype(insn_id, class, types::Class)
4331+
}
43174332
}
43184333
}
43194334

zjit/src/hir/opt_tests.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7983,6 +7983,191 @@ mod hir_opt_tests {
79837983
");
79847984
}
79857985

7986+
#[test]
7987+
fn test_specialize_class_eqq() {
7988+
eval(r#"
7989+
def test(o) = String === o
7990+
test("asdf")
7991+
"#);
7992+
assert_snapshot!(hir_string("test"), @r"
7993+
fn test@<compiled>:2:
7994+
bb0():
7995+
EntryPoint interpreter
7996+
v1:BasicObject = LoadSelf
7997+
v2:BasicObject = GetLocal l0, SP@4
7998+
Jump bb2(v1, v2)
7999+
bb1(v5:BasicObject, v6:BasicObject):
8000+
EntryPoint JIT(0)
8001+
Jump bb2(v5, v6)
8002+
bb2(v8:BasicObject, v9:BasicObject):
8003+
PatchPoint SingleRactorMode
8004+
PatchPoint StableConstantNames(0x1000, String)
8005+
v26:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8006+
PatchPoint NoEPEscape(test)
8007+
PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020)
8008+
PatchPoint NoSingletonClass(Class@0x1010)
8009+
v30:BoolExact = IsA v9, v26
8010+
IncrCounter inline_cfunc_optimized_send_count
8011+
CheckInterrupts
8012+
Return v30
8013+
");
8014+
}
8015+
8016+
#[test]
8017+
fn test_dont_specialize_module_eqq() {
8018+
eval(r#"
8019+
def test(o) = Kernel === o
8020+
test("asdf")
8021+
"#);
8022+
assert_snapshot!(hir_string("test"), @r"
8023+
fn test@<compiled>:2:
8024+
bb0():
8025+
EntryPoint interpreter
8026+
v1:BasicObject = LoadSelf
8027+
v2:BasicObject = GetLocal l0, SP@4
8028+
Jump bb2(v1, v2)
8029+
bb1(v5:BasicObject, v6:BasicObject):
8030+
EntryPoint JIT(0)
8031+
Jump bb2(v5, v6)
8032+
bb2(v8:BasicObject, v9:BasicObject):
8033+
PatchPoint SingleRactorMode
8034+
PatchPoint StableConstantNames(0x1000, Kernel)
8035+
v26:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8036+
PatchPoint NoEPEscape(test)
8037+
PatchPoint MethodRedefined(Module@0x1010, ===@0x1018, cme:0x1020)
8038+
PatchPoint NoSingletonClass(Module@0x1010)
8039+
IncrCounter inline_cfunc_optimized_send_count
8040+
v31:BoolExact = CCall Module#===@0x1048, v26, v9
8041+
CheckInterrupts
8042+
Return v31
8043+
");
8044+
}
8045+
8046+
#[test]
8047+
fn test_specialize_is_a_class() {
8048+
eval(r#"
8049+
def test(o) = o.is_a?(String)
8050+
test("asdf")
8051+
"#);
8052+
assert_snapshot!(hir_string("test"), @r"
8053+
fn test@<compiled>:2:
8054+
bb0():
8055+
EntryPoint interpreter
8056+
v1:BasicObject = LoadSelf
8057+
v2:BasicObject = GetLocal l0, SP@4
8058+
Jump bb2(v1, v2)
8059+
bb1(v5:BasicObject, v6:BasicObject):
8060+
EntryPoint JIT(0)
8061+
Jump bb2(v5, v6)
8062+
bb2(v8:BasicObject, v9:BasicObject):
8063+
PatchPoint SingleRactorMode
8064+
PatchPoint StableConstantNames(0x1000, String)
8065+
v24:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8066+
PatchPoint MethodRedefined(String@0x1008, is_a?@0x1010, cme:0x1018)
8067+
PatchPoint NoSingletonClass(String@0x1008)
8068+
v28:StringExact = GuardType v9, StringExact
8069+
v29:BoolExact = IsA v28, v24
8070+
IncrCounter inline_cfunc_optimized_send_count
8071+
CheckInterrupts
8072+
Return v29
8073+
");
8074+
}
8075+
8076+
#[test]
8077+
fn test_dont_specialize_is_a_module() {
8078+
eval(r#"
8079+
def test(o) = o.is_a?(Kernel)
8080+
test("asdf")
8081+
"#);
8082+
assert_snapshot!(hir_string("test"), @r"
8083+
fn test@<compiled>:2:
8084+
bb0():
8085+
EntryPoint interpreter
8086+
v1:BasicObject = LoadSelf
8087+
v2:BasicObject = GetLocal l0, SP@4
8088+
Jump bb2(v1, v2)
8089+
bb1(v5:BasicObject, v6:BasicObject):
8090+
EntryPoint JIT(0)
8091+
Jump bb2(v5, v6)
8092+
bb2(v8:BasicObject, v9:BasicObject):
8093+
PatchPoint SingleRactorMode
8094+
PatchPoint StableConstantNames(0x1000, Kernel)
8095+
v24:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8096+
PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020)
8097+
PatchPoint NoSingletonClass(String@0x1010)
8098+
v28:StringExact = GuardType v9, StringExact
8099+
v29:BasicObject = CCallWithFrame Kernel#is_a?@0x1048, v28, v24
8100+
CheckInterrupts
8101+
Return v29
8102+
");
8103+
}
8104+
8105+
#[test]
8106+
fn test_elide_is_a() {
8107+
eval(r#"
8108+
def test(o)
8109+
o.is_a?(Integer)
8110+
5
8111+
end
8112+
test("asdf")
8113+
"#);
8114+
assert_snapshot!(hir_string("test"), @r"
8115+
fn test@<compiled>:3:
8116+
bb0():
8117+
EntryPoint interpreter
8118+
v1:BasicObject = LoadSelf
8119+
v2:BasicObject = GetLocal l0, SP@4
8120+
Jump bb2(v1, v2)
8121+
bb1(v5:BasicObject, v6:BasicObject):
8122+
EntryPoint JIT(0)
8123+
Jump bb2(v5, v6)
8124+
bb2(v8:BasicObject, v9:BasicObject):
8125+
PatchPoint SingleRactorMode
8126+
PatchPoint StableConstantNames(0x1000, Integer)
8127+
v28:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8128+
PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020)
8129+
PatchPoint NoSingletonClass(String@0x1010)
8130+
v32:StringExact = GuardType v9, StringExact
8131+
IncrCounter inline_cfunc_optimized_send_count
8132+
v21:Fixnum[5] = Const Value(5)
8133+
CheckInterrupts
8134+
Return v21
8135+
");
8136+
}
8137+
8138+
#[test]
8139+
fn test_elide_class_eqq() {
8140+
eval(r#"
8141+
def test(o)
8142+
Integer === o
8143+
5
8144+
end
8145+
test("asdf")
8146+
"#);
8147+
assert_snapshot!(hir_string("test"), @r"
8148+
fn test@<compiled>:3:
8149+
bb0():
8150+
EntryPoint interpreter
8151+
v1:BasicObject = LoadSelf
8152+
v2:BasicObject = GetLocal l0, SP@4
8153+
Jump bb2(v1, v2)
8154+
bb1(v5:BasicObject, v6:BasicObject):
8155+
EntryPoint JIT(0)
8156+
Jump bb2(v5, v6)
8157+
bb2(v8:BasicObject, v9:BasicObject):
8158+
PatchPoint SingleRactorMode
8159+
PatchPoint StableConstantNames(0x1000, Integer)
8160+
v30:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
8161+
PatchPoint NoEPEscape(test)
8162+
PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020)
8163+
PatchPoint NoSingletonClass(Class@0x1010)
8164+
IncrCounter inline_cfunc_optimized_send_count
8165+
v23:Fixnum[5] = Const Value(5)
8166+
CheckInterrupts
8167+
Return v23
8168+
");
8169+
}
8170+
79868171
#[test]
79878172
fn counting_complex_feature_use_for_fallback() {
79888173
eval("

0 commit comments

Comments
 (0)