Skip to content

Commit e0bb3fb

Browse files
authored
ZJIT: Inline Integer#<< for constant rhs (ruby#15258)
This is good for protoboeuf and other binary parsing
1 parent 8728406 commit e0bb3fb

File tree

6 files changed

+151
-2
lines changed

6 files changed

+151
-2
lines changed

zjit/src/asm/arm64/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ pub fn lsl(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) {
649649

650650
ShiftImm::lsl(rd.reg_no, rn.reg_no, uimm as u8, rd.num_bits).into()
651651
},
652-
_ => panic!("Invalid operands combination to lsl instruction")
652+
_ => panic!("Invalid operands combination {rd:?} {rn:?} {shift:?} to lsl instruction")
653653
};
654654

655655
cb.write_bytes(&bytes);

zjit/src/codegen.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,12 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
402402
Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right)),
403403
Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right)),
404404
Insn::FixnumXor { left, right } => gen_fixnum_xor(asm, opnd!(left), opnd!(right)),
405+
&Insn::FixnumLShift { left, right, state } => {
406+
// We only create FixnumLShift when we know the shift amount statically and it's in [0,
407+
// 63].
408+
let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64;
409+
gen_fixnum_lshift(jit, asm, opnd!(left), shift_amount, &function.frame_state(state))
410+
}
405411
&Insn::FixnumMod { left, right, state } => gen_fixnum_mod(jit, asm, opnd!(left), opnd!(right), &function.frame_state(state)),
406412
Insn::IsNil { val } => gen_isnil(asm, opnd!(val)),
407413
&Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
@@ -1700,6 +1706,20 @@ fn gen_fixnum_xor(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> lir
17001706
asm.add(out_val, Opnd::UImm(1))
17011707
}
17021708

1709+
/// Compile Fixnum << Fixnum
1710+
fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, shift_amount: u64, state: &FrameState) -> lir::Opnd {
1711+
// Shift amount is known statically to be in the range [0, 63]
1712+
assert!(shift_amount < 64);
1713+
let in_val = asm.sub(left, Opnd::UImm(1)); // Drop tag bit
1714+
let out_val = asm.lshift(in_val, shift_amount.into());
1715+
let unshifted = asm.rshift(out_val, shift_amount.into());
1716+
asm.cmp(in_val, unshifted);
1717+
asm.jne(side_exit(jit, state, FixnumLShiftOverflow));
1718+
// Re-tag the output value
1719+
let out_val = asm.add(out_val, 1.into());
1720+
out_val
1721+
}
1722+
17031723
fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd {
17041724
// Check for left % 0, which raises ZeroDivisionError
17051725
asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0)));

zjit/src/cruby_methods.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ pub fn init() -> Annotations {
240240
annotate!(rb_cInteger, ">=", inline_integer_ge);
241241
annotate!(rb_cInteger, "<", inline_integer_lt);
242242
annotate!(rb_cInteger, "<=", inline_integer_le);
243+
annotate!(rb_cInteger, "<<", inline_integer_lshift);
243244
annotate!(rb_cString, "to_s", inline_string_to_s, types::StringExact);
244245
let thread_singleton = unsafe { rb_singleton_class(rb_cThread) };
245246
annotate!(thread_singleton, "current", inline_thread_current, types::BasicObject, no_gc, leaf);
@@ -546,6 +547,15 @@ fn inline_integer_le(fun: &mut hir::Function, block: hir::BlockId, recv: hir::In
546547
try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLe { left, right }, BOP_LE, recv, other, state)
547548
}
548549

550+
fn inline_integer_lshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
551+
let &[other] = args else { return None; };
552+
// Only convert to FixnumLShift if we know the shift amount is known at compile-time and could
553+
// plausibly create a fixnum.
554+
let Some(other_value) = fun.type_of(other).fixnum_value() else { return None; };
555+
if other_value < 0 || other_value > 63 { return None; }
556+
try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLShift { left, right, state }, BOP_LTLT, recv, other, state)
557+
}
558+
549559
fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
550560
let &[other] = args else { return None; };
551561
let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other });

zjit/src/hir.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ pub enum SideExitReason {
485485
FixnumAddOverflow,
486486
FixnumSubOverflow,
487487
FixnumMultOverflow,
488+
FixnumLShiftOverflow,
488489
GuardType(Type),
489490
GuardTypeNot(Type),
490491
GuardShape(ShapeId),
@@ -868,7 +869,7 @@ pub enum Insn {
868869
/// Non-local control flow. See the throw YARV instruction
869870
Throw { throw_state: u32, val: InsnId, state: InsnId },
870871

871-
/// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^
872+
/// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^, <<
872873
FixnumAdd { left: InsnId, right: InsnId, state: InsnId },
873874
FixnumSub { left: InsnId, right: InsnId, state: InsnId },
874875
FixnumMult { left: InsnId, right: InsnId, state: InsnId },
@@ -883,6 +884,7 @@ pub enum Insn {
883884
FixnumAnd { left: InsnId, right: InsnId },
884885
FixnumOr { left: InsnId, right: InsnId },
885886
FixnumXor { left: InsnId, right: InsnId },
887+
FixnumLShift { left: InsnId, right: InsnId, state: InsnId },
886888

887889
// Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined
888890
ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId },
@@ -979,6 +981,7 @@ impl Insn {
979981
Insn::FixnumAnd { .. } => false,
980982
Insn::FixnumOr { .. } => false,
981983
Insn::FixnumXor { .. } => false,
984+
Insn::FixnumLShift { .. } => false,
982985
Insn::GetLocal { .. } => false,
983986
Insn::IsNil { .. } => false,
984987
Insn::LoadPC => false,
@@ -1218,6 +1221,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
12181221
Insn::FixnumAnd { left, right, .. } => { write!(f, "FixnumAnd {left}, {right}") },
12191222
Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") },
12201223
Insn::FixnumXor { left, right, .. } => { write!(f, "FixnumXor {left}, {right}") },
1224+
Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") },
12211225
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
12221226
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
12231227
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
@@ -1836,6 +1840,7 @@ impl Function {
18361840
&FixnumAnd { left, right } => FixnumAnd { left: find!(left), right: find!(right) },
18371841
&FixnumOr { left, right } => FixnumOr { left: find!(left), right: find!(right) },
18381842
&FixnumXor { left, right } => FixnumXor { left: find!(left), right: find!(right) },
1843+
&FixnumLShift { left, right, state } => FixnumLShift { left: find!(left), right: find!(right), state },
18391844
&ObjToString { val, cd, state } => ObjToString {
18401845
val: find!(val),
18411846
cd,
@@ -2054,6 +2059,7 @@ impl Function {
20542059
Insn::FixnumAnd { .. } => types::Fixnum,
20552060
Insn::FixnumOr { .. } => types::Fixnum,
20562061
Insn::FixnumXor { .. } => types::Fixnum,
2062+
Insn::FixnumLShift { .. } => types::Fixnum,
20572063
Insn::PutSpecialObject { .. } => types::BasicObject,
20582064
Insn::SendWithoutBlock { .. } => types::BasicObject,
20592065
Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
@@ -3506,6 +3512,7 @@ impl Function {
35063512
| &Insn::FixnumDiv { left, right, state }
35073513
| &Insn::FixnumMod { left, right, state }
35083514
| &Insn::ArrayExtend { left, right, state }
3515+
| &Insn::FixnumLShift { left, right, state }
35093516
=> {
35103517
worklist.push_back(left);
35113518
worklist.push_back(right);
@@ -4271,6 +4278,7 @@ impl Function {
42714278
| Insn::FixnumAnd { left, right }
42724279
| Insn::FixnumOr { left, right }
42734280
| Insn::FixnumXor { left, right }
4281+
| Insn::FixnumLShift { left, right, .. }
42744282
| Insn::NewRangeFixnum { low: left, high: right, .. }
42754283
=> {
42764284
self.assert_subtype(insn_id, left, types::Fixnum)?;

zjit/src/hir/opt_tests.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6466,6 +6466,115 @@ mod hir_opt_tests {
64666466
");
64676467
}
64686468

6469+
#[test]
6470+
fn test_inline_integer_ltlt_with_known_fixnum() {
6471+
eval("
6472+
def test(x) = x << 5
6473+
test(4)
6474+
");
6475+
assert_contains_opcode("test", YARVINSN_opt_ltlt);
6476+
assert_snapshot!(hir_string("test"), @r"
6477+
fn test@<compiled>:2:
6478+
bb0():
6479+
EntryPoint interpreter
6480+
v1:BasicObject = LoadSelf
6481+
v2:BasicObject = GetLocal l0, SP@4
6482+
Jump bb2(v1, v2)
6483+
bb1(v5:BasicObject, v6:BasicObject):
6484+
EntryPoint JIT(0)
6485+
Jump bb2(v5, v6)
6486+
bb2(v8:BasicObject, v9:BasicObject):
6487+
v14:Fixnum[5] = Const Value(5)
6488+
PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
6489+
v24:Fixnum = GuardType v9, Fixnum
6490+
v25:Fixnum = FixnumLShift v24, v14
6491+
IncrCounter inline_cfunc_optimized_send_count
6492+
CheckInterrupts
6493+
Return v25
6494+
");
6495+
}
6496+
6497+
#[test]
6498+
fn test_dont_inline_integer_ltlt_with_negative() {
6499+
eval("
6500+
def test(x) = x << -5
6501+
test(4)
6502+
");
6503+
assert_contains_opcode("test", YARVINSN_opt_ltlt);
6504+
assert_snapshot!(hir_string("test"), @r"
6505+
fn test@<compiled>:2:
6506+
bb0():
6507+
EntryPoint interpreter
6508+
v1:BasicObject = LoadSelf
6509+
v2:BasicObject = GetLocal l0, SP@4
6510+
Jump bb2(v1, v2)
6511+
bb1(v5:BasicObject, v6:BasicObject):
6512+
EntryPoint JIT(0)
6513+
Jump bb2(v5, v6)
6514+
bb2(v8:BasicObject, v9:BasicObject):
6515+
v14:Fixnum[-5] = Const Value(-5)
6516+
PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
6517+
v24:Fixnum = GuardType v9, Fixnum
6518+
v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14
6519+
CheckInterrupts
6520+
Return v25
6521+
");
6522+
}
6523+
6524+
#[test]
6525+
fn test_dont_inline_integer_ltlt_with_out_of_range() {
6526+
eval("
6527+
def test(x) = x << 64
6528+
test(4)
6529+
");
6530+
assert_contains_opcode("test", YARVINSN_opt_ltlt);
6531+
assert_snapshot!(hir_string("test"), @r"
6532+
fn test@<compiled>:2:
6533+
bb0():
6534+
EntryPoint interpreter
6535+
v1:BasicObject = LoadSelf
6536+
v2:BasicObject = GetLocal l0, SP@4
6537+
Jump bb2(v1, v2)
6538+
bb1(v5:BasicObject, v6:BasicObject):
6539+
EntryPoint JIT(0)
6540+
Jump bb2(v5, v6)
6541+
bb2(v8:BasicObject, v9:BasicObject):
6542+
v14:Fixnum[64] = Const Value(64)
6543+
PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
6544+
v24:Fixnum = GuardType v9, Fixnum
6545+
v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14
6546+
CheckInterrupts
6547+
Return v25
6548+
");
6549+
}
6550+
6551+
#[test]
6552+
fn test_dont_inline_integer_ltlt_with_unknown_fixnum() {
6553+
eval("
6554+
def test(x, y) = x << y
6555+
test(4, 5)
6556+
");
6557+
assert_contains_opcode("test", YARVINSN_opt_ltlt);
6558+
assert_snapshot!(hir_string("test"), @r"
6559+
fn test@<compiled>:2:
6560+
bb0():
6561+
EntryPoint interpreter
6562+
v1:BasicObject = LoadSelf
6563+
v2:BasicObject = GetLocal l0, SP@5
6564+
v3:BasicObject = GetLocal l0, SP@4
6565+
Jump bb2(v1, v2, v3)
6566+
bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
6567+
EntryPoint JIT(0)
6568+
Jump bb2(v6, v7, v8)
6569+
bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
6570+
PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
6571+
v26:Fixnum = GuardType v11, Fixnum
6572+
v27:BasicObject = CCallWithFrame Integer#<<@0x1038, v26, v12
6573+
CheckInterrupts
6574+
Return v27
6575+
");
6576+
}
6577+
64696578
#[test]
64706579
fn test_optimize_string_append() {
64716580
eval(r#"

zjit/src/stats.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ make_counters! {
142142
exit_fixnum_add_overflow,
143143
exit_fixnum_sub_overflow,
144144
exit_fixnum_mult_overflow,
145+
exit_fixnum_lshift_overflow,
145146
exit_fixnum_mod_by_zero,
146147
exit_box_fixnum_overflow,
147148
exit_guard_type_failure,
@@ -423,6 +424,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter {
423424
FixnumAddOverflow => exit_fixnum_add_overflow,
424425
FixnumSubOverflow => exit_fixnum_sub_overflow,
425426
FixnumMultOverflow => exit_fixnum_mult_overflow,
427+
FixnumLShiftOverflow => exit_fixnum_lshift_overflow,
426428
FixnumModByZero => exit_fixnum_mod_by_zero,
427429
BoxFixnumOverflow => exit_box_fixnum_overflow,
428430
GuardType(_) => exit_guard_type_failure,

0 commit comments

Comments
 (0)