Skip to content

Commit f7fe436

Browse files
authored
ZJIT: Optimize ObjToString with type guards (ruby#14469)
* failing test for ObjToString optimization with GuardType * profile ObjToString receiver and rewrite with guard * adjust integration tests for objtostring type guard optimization * Implement new GuardTypeNot HIR; objtostring sends to_s directly on profiled nonstrings * codegen for GuardTypeNot * typo fixes * better name for tests; fix side exit reason for GuardTypeNot * revert accidental change * make bindgen * Fix is_string to identify subclasses of String; fix codegen for identifying if val is String
1 parent fe362be commit f7fe436

File tree

7 files changed

+237
-23
lines changed

7 files changed

+237
-23
lines changed

insns.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ objtostring
939939
(VALUE recv)
940940
(VALUE val)
941941
// attr bool leaf = false;
942+
// attr bool zjit_profile = true;
942943
{
943944
val = vm_objtostring(GET_ISEQ(), recv, cd);
944945

test/ruby/test_zjit.rb

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,6 +2088,61 @@ def test(str)
20882088
}
20892089
end
20902090

2091+
def test_objtostring_profiled_string_fastpath
2092+
assert_compiles '"foo"', %q{
2093+
def test(str)
2094+
"#{str}"
2095+
end
2096+
test('foo'); test('foo') # profile as string
2097+
}, call_threshold: 2
2098+
end
2099+
2100+
def test_objtostring_profiled_string_subclass_fastpath
2101+
assert_compiles '"foo"', %q{
2102+
class MyString < String; end
2103+
2104+
def test(str)
2105+
"#{str}"
2106+
end
2107+
2108+
foo = MyString.new("foo")
2109+
test(foo); test(foo) # still profiles as string
2110+
}, call_threshold: 2
2111+
end
2112+
2113+
def test_objtostring_profiled_string_fastpath_exits_on_nonstring
2114+
assert_compiles '"1"', %q{
2115+
def test(str)
2116+
"#{str}"
2117+
end
2118+
2119+
test('foo') # profile as string
2120+
test(1)
2121+
}, call_threshold: 2
2122+
end
2123+
2124+
def test_objtostring_profiled_nonstring_calls_to_s
2125+
assert_compiles '"[1, 2, 3]"', %q{
2126+
def test(str)
2127+
"#{str}"
2128+
end
2129+
2130+
test([1,2,3]); # profile as nonstring
2131+
test([1,2,3]);
2132+
}, call_threshold: 2
2133+
end
2134+
2135+
def test_objtostring_profiled_nonstring_guard_exits_when_string
2136+
assert_compiles '"foo"', %q{
2137+
def test(str)
2138+
"#{str}"
2139+
end
2140+
2141+
test([1,2,3]); # profiles as nonstring
2142+
test('foo');
2143+
}, call_threshold: 2
2144+
end
2145+
20912146
def test_string_bytesize_with_guard
20922147
assert_compiles '5', %q{
20932148
def test(str)

zjit/src/codegen.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
389389
&Insn::IsMethodCfunc { val, cd, cfunc } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
390390
Insn::Test { val } => gen_test(asm, opnd!(val)),
391391
Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
392+
Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
392393
Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(jit, asm, opnd!(val), *expected, &function.frame_state(*state)),
393394
Insn::PatchPoint { invariant, state } => no_output!(gen_patch_point(jit, asm, invariant, &function.frame_state(*state))),
394395
Insn::CCall { cfun, args, name: _, return_type: _, elidable: _ } => gen_ccall(asm, *cfun, opnds!(args)),
@@ -1375,6 +1376,26 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard
13751376

13761377
asm.cmp(klass, Opnd::Value(expected_class));
13771378
asm.jne(side_exit);
1379+
} else if guard_type.is_subtype(types::String) {
1380+
let side = side_exit(jit, state, GuardType(guard_type));
1381+
1382+
// Check special constant
1383+
asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64));
1384+
asm.jnz(side.clone());
1385+
1386+
// Check false
1387+
asm.cmp(val, Qfalse.into());
1388+
asm.je(side.clone());
1389+
1390+
let val = match val {
1391+
Opnd::Reg(_) | Opnd::VReg { .. } => val,
1392+
_ => asm.load(val),
1393+
};
1394+
1395+
let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS));
1396+
let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
1397+
asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64));
1398+
asm.jne(side);
13781399
} else if guard_type.bit_equal(types::HeapObject) {
13791400
let side_exit = side_exit(jit, state, GuardType(guard_type));
13801401
asm.cmp(val, Opnd::Value(Qfalse));
@@ -1387,6 +1408,38 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard
13871408
val
13881409
}
13891410

1411+
fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> lir::Opnd {
1412+
if guard_type.is_subtype(types::String) {
1413+
// We only exit if val *is* a String. Otherwise we fall through.
1414+
let cont = asm.new_label("guard_type_not_string_cont");
1415+
let side = side_exit(jit, state, GuardTypeNot(guard_type));
1416+
1417+
// Continue if special constant (not string)
1418+
asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64));
1419+
asm.jnz(cont.clone());
1420+
1421+
// Continue if false (not string)
1422+
asm.cmp(val, Qfalse.into());
1423+
asm.je(cont.clone());
1424+
1425+
let val = match val {
1426+
Opnd::Reg(_) | Opnd::VReg { .. } => val,
1427+
_ => asm.load(val),
1428+
};
1429+
1430+
let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS));
1431+
let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
1432+
asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64));
1433+
asm.je(side);
1434+
1435+
// Otherwise (non-string heap object), continue.
1436+
asm.write_label(cont);
1437+
} else {
1438+
unimplemented!("unsupported type: {guard_type}");
1439+
}
1440+
val
1441+
}
1442+
13901443
/// Compile an identity check with a side exit
13911444
fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> lir::Opnd {
13921445
asm.cmp(val, Opnd::Value(expected));

zjit/src/cruby_bindings.inc.rs

Lines changed: 18 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

zjit/src/hir.rs

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ pub enum SideExitReason {
455455
FixnumSubOverflow,
456456
FixnumMultOverflow,
457457
GuardType(Type),
458+
GuardTypeNot(Type),
458459
GuardShape(ShapeId),
459460
GuardBitEquals(VALUE),
460461
PatchPoint(Invariant),
@@ -474,6 +475,7 @@ impl std::fmt::Display for SideExitReason {
474475
SideExitReason::UnknownNewarraySend(VM_OPT_NEWARRAY_SEND_PACK_BUFFER) => write!(f, "UnknownNewarraySend(PACK_BUFFER)"),
475476
SideExitReason::UnknownNewarraySend(VM_OPT_NEWARRAY_SEND_INCLUDE_P) => write!(f, "UnknownNewarraySend(INCLUDE_P)"),
476477
SideExitReason::GuardType(guard_type) => write!(f, "GuardType({guard_type})"),
478+
SideExitReason::GuardTypeNot(guard_type) => write!(f, "GuardTypeNot({guard_type})"),
477479
SideExitReason::GuardBitEquals(value) => write!(f, "GuardBitEquals({})", value.print(&PtrPrintMap::identity())),
478480
SideExitReason::PatchPoint(invariant) => write!(f, "PatchPoint({invariant})"),
479481
_ => write!(f, "{self:?}"),
@@ -623,6 +625,7 @@ pub enum Insn {
623625

624626
/// Side-exit if val doesn't have the expected type.
625627
GuardType { val: InsnId, guard_type: Type, state: InsnId },
628+
GuardTypeNot { val: InsnId, guard_type: Type, state: InsnId },
626629
/// Side-exit if val is not the expected VALUE.
627630
GuardBitEquals { val: InsnId, expected: VALUE, state: InsnId },
628631
/// Side-exit if val doesn't have the expected shape.
@@ -859,6 +862,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
859862
Insn::FixnumAnd { left, right, .. } => { write!(f, "FixnumAnd {left}, {right}") },
860863
Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") },
861864
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
865+
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
862866
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
863867
&Insn::GuardShape { val, shape, .. } => { write!(f, "GuardShape {val}, {:p}", self.ptr_map.map_shape(shape)) },
864868
Insn::PatchPoint { invariant, .. } => { write!(f, "PatchPoint {}", invariant.print(self.ptr_map)) },
@@ -1285,6 +1289,7 @@ impl Function {
12851289
&IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) },
12861290
&IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) },
12871291
&GuardType { val, guard_type, state } => GuardType { val: find!(val), guard_type, state },
1292+
&GuardTypeNot { val, guard_type, state } => GuardTypeNot { val: find!(val), guard_type, state },
12881293
&GuardBitEquals { val, expected, state } => GuardBitEquals { val: find!(val), expected, state },
12891294
&GuardShape { val, shape, state } => GuardShape { val: find!(val), shape, state },
12901295
&FixnumAdd { left, right, state } => FixnumAdd { left: find!(left), right: find!(right), state },
@@ -1430,6 +1435,7 @@ impl Function {
14301435
Insn::ObjectAlloc { .. } => types::HeapObject,
14311436
Insn::CCall { return_type, .. } => *return_type,
14321437
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
1438+
Insn::GuardTypeNot { .. } => types::BasicObject,
14331439
Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
14341440
Insn::GuardShape { val, .. } => self.type_of(*val),
14351441
Insn::FixnumAdd { .. } => types::Fixnum,
@@ -1546,9 +1552,10 @@ impl Function {
15461552
fn chase_insn(&self, insn: InsnId) -> InsnId {
15471553
let id = self.union_find.borrow().find_const(insn);
15481554
match self.insns[id.0] {
1549-
Insn::GuardType { val, .. } => self.chase_insn(val),
1550-
Insn::GuardShape { val, .. } => self.chase_insn(val),
1551-
Insn::GuardBitEquals { val, .. } => self.chase_insn(val),
1555+
Insn::GuardType { val, .. }
1556+
| Insn::GuardTypeNot { val, .. }
1557+
| Insn::GuardShape { val, .. }
1558+
| Insn::GuardBitEquals { val, .. } => self.chase_insn(val),
15521559
_ => id,
15531560
}
15541561
}
@@ -1791,12 +1798,26 @@ impl Function {
17911798
self.insn_types[replacement.0] = self.infer_type(replacement);
17921799
self.make_equal_to(insn_id, replacement);
17931800
}
1794-
Insn::ObjToString { val, .. } => {
1801+
Insn::ObjToString { val, cd, state, .. } => {
17951802
if self.is_a(val, types::String) {
17961803
// behaves differently from `SendWithoutBlock` with `mid:to_s` because ObjToString should not have a patch point for String to_s being redefined
1797-
self.make_equal_to(insn_id, val);
1804+
self.make_equal_to(insn_id, val); continue;
1805+
}
1806+
1807+
let frame_state = self.frame_state(state);
1808+
let Some(recv_type) = self.profiled_type_of_at(val, frame_state.insn_idx) else {
1809+
self.push_insn_id(block, insn_id); continue
1810+
};
1811+
1812+
if recv_type.is_string() {
1813+
let guard = self.push_insn(block, Insn::GuardType { val: val, guard_type: types::String, state: state });
1814+
// Infer type so AnyToString can fold off this
1815+
self.insn_types[guard.0] = self.infer_type(guard);
1816+
self.make_equal_to(insn_id, guard);
17981817
} else {
1799-
self.push_insn_id(block, insn_id);
1818+
self.push_insn(block, Insn::GuardTypeNot { val: val, guard_type: types::String, state: state});
1819+
let send_to_s = self.push_insn(block, Insn::SendWithoutBlock { self_val: val, cd: cd, args: vec![], state: state});
1820+
self.make_equal_to(insn_id, send_to_s);
18001821
}
18011822
}
18021823
Insn::AnyToString { str, .. } => {
@@ -2193,6 +2214,7 @@ impl Function {
21932214
| &Insn::StringCopy { val, state, .. }
21942215
| &Insn::ObjectAlloc { val, state }
21952216
| &Insn::GuardType { val, state, .. }
2217+
| &Insn::GuardTypeNot { val, state, .. }
21962218
| &Insn::GuardBitEquals { val, state, .. }
21972219
| &Insn::GuardShape { val, state, .. }
21982220
| &Insn::ToArray { val, state }
@@ -8281,6 +8303,71 @@ mod opt_tests {
82818303
");
82828304
}
82838305

8306+
#[test]
8307+
fn test_optimize_objtostring_anytostring_recv_profiled() {
8308+
eval("
8309+
def test(a)
8310+
\"#{a}\"
8311+
end
8312+
test('foo'); test('foo')
8313+
");
8314+
8315+
assert_snapshot!(hir_string("test"), @r"
8316+
fn test@<compiled>:3:
8317+
bb0(v0:BasicObject, v1:BasicObject):
8318+
v5:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
8319+
v17:String = GuardType v1, String
8320+
v11:StringExact = StringConcat v5, v17
8321+
CheckInterrupts
8322+
Return v11
8323+
");
8324+
}
8325+
8326+
#[test]
8327+
fn test_optimize_objtostring_anytostring_recv_profiled_string_subclass() {
8328+
eval("
8329+
class MyString < String; end
8330+
8331+
def test(a)
8332+
\"#{a}\"
8333+
end
8334+
foo = MyString.new('foo')
8335+
test(MyString.new(foo)); test(MyString.new(foo))
8336+
");
8337+
8338+
assert_snapshot!(hir_string("test"), @r"
8339+
fn test@<compiled>:5:
8340+
bb0(v0:BasicObject, v1:BasicObject):
8341+
v5:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
8342+
v17:String = GuardType v1, String
8343+
v11:StringExact = StringConcat v5, v17
8344+
CheckInterrupts
8345+
Return v11
8346+
");
8347+
}
8348+
8349+
#[test]
8350+
fn test_optimize_objtostring_profiled_nonstring_falls_back_to_send() {
8351+
eval("
8352+
def test(a)
8353+
\"#{a}\"
8354+
end
8355+
test([1,2,3]); test([1,2,3]) # No fast path for array
8356+
");
8357+
8358+
assert_snapshot!(hir_string("test"), @r"
8359+
fn test@<compiled>:3:
8360+
bb0(v0:BasicObject, v1:BasicObject):
8361+
v5:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
8362+
v17:BasicObject = GuardTypeNot v1, String
8363+
v18:BasicObject = SendWithoutBlock v1, :to_s
8364+
v9:String = AnyToString v1, str: v18
8365+
v11:StringExact = StringConcat v5, v9
8366+
CheckInterrupts
8367+
Return v11
8368+
");
8369+
}
8370+
82848371
#[test]
82858372
fn test_branchnil_nil() {
82868373
eval("

zjit/src/profile.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ fn profile_insn(bare_opcode: ruby_vminsn_type, ec: EcPtr) {
7575
YARVINSN_opt_empty_p => profile_operands(profiler, profile, 1),
7676
YARVINSN_opt_not => profile_operands(profiler, profile, 1),
7777
YARVINSN_getinstancevariable => profile_self(profiler, profile),
78+
YARVINSN_objtostring => profile_operands(profiler, profile, 1),
7879
YARVINSN_opt_send_without_block => {
7980
let cd: *const rb_call_data = profiler.insn_opnd(0).as_ptr();
8081
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -235,6 +236,20 @@ impl ProfiledType {
235236
self.class == unsafe { rb_cInteger } && self.flags.is_immediate()
236237
}
237238

239+
pub fn is_string(&self) -> bool {
240+
// Fast paths for immediates and exact-class
241+
if self.flags.is_immediate() {
242+
return false;
243+
}
244+
245+
let string = unsafe { rb_cString };
246+
if self.class == string{
247+
return true;
248+
}
249+
250+
self.class.is_subclass_of(string) == ClassRelationship::Subclass
251+
}
252+
238253
pub fn is_flonum(&self) -> bool {
239254
self.class == unsafe { rb_cFloat } && self.flags.is_immediate()
240255
}

0 commit comments

Comments
 (0)