Skip to content

Commit 5327f78

Browse files
committed
Merge pull request atomvm#2213 from pguyot/w12/jit-inline-mul-bsl-bsr
JIT: inline mul/bsl/bsr BIF operations when provably safe Continuation of: - atomvm#2203 These changes are made under both the "Apache 2.0" and the "GNU Lesser General Public License 2.1 or later" license terms (dual license). SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
2 parents f60c037 + 8323034 commit 5327f78

File tree

6 files changed

+341
-10
lines changed

6 files changed

+341
-10
lines changed

libs/jit/src/jit.erl

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3359,6 +3359,65 @@ op_gc_bif2(
33593359
Arg2Value = Arg2 bsr 4,
33603360
Range2 = {Arg2Value, Arg2Value},
33613361
op_gc_bif2_bxor(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Range2);
3362+
% mul - both typed integers with range: inline if proven small
3363+
op_gc_bif2(
3364+
MMod,
3365+
MSt0,
3366+
FailLabel,
3367+
Live,
3368+
Bif,
3369+
erlang,
3370+
'*',
3371+
{typed, Arg1, {t_integer, Range1}},
3372+
{typed, Arg2, {t_integer, Range2}},
3373+
Dest
3374+
) ->
3375+
op_gc_bif2_mul(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Range2);
3376+
op_gc_bif2(
3377+
MMod,
3378+
MSt0,
3379+
FailLabel,
3380+
Live,
3381+
Bif,
3382+
erlang,
3383+
'*',
3384+
{typed, Arg1, {t_integer, Range1}},
3385+
Arg2,
3386+
Dest
3387+
) when is_integer(Arg2), Arg2 band ?TERM_IMMED_TAG_MASK =:= ?TERM_INTEGER_TAG ->
3388+
Arg2Value = Arg2 bsr 4,
3389+
Range2 = {Arg2Value, Arg2Value},
3390+
op_gc_bif2_mul(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Range2);
3391+
% bsl - typed integer with literal shift amount: inline if result fits
3392+
op_gc_bif2(
3393+
MMod,
3394+
MSt0,
3395+
FailLabel,
3396+
Live,
3397+
Bif,
3398+
erlang,
3399+
'bsl',
3400+
{typed, Arg1, {t_integer, Range1}},
3401+
Arg2,
3402+
Dest
3403+
) when is_integer(Arg2), Arg2 band ?TERM_IMMED_TAG_MASK =:= ?TERM_INTEGER_TAG ->
3404+
Arg2Value = Arg2 bsr 4,
3405+
op_gc_bif2_bsl(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Arg2Value);
3406+
% bsr - typed integer with literal shift amount: inline if non-negative and small
3407+
op_gc_bif2(
3408+
MMod,
3409+
MSt0,
3410+
FailLabel,
3411+
Live,
3412+
Bif,
3413+
erlang,
3414+
'bsr',
3415+
{typed, Arg1, {t_integer, Range1}},
3416+
Arg2,
3417+
Dest
3418+
) when is_integer(Arg2), Arg2 band ?TERM_IMMED_TAG_MASK =:= ?TERM_INTEGER_TAG ->
3419+
Arg2Value = Arg2 bsr 4,
3420+
op_gc_bif2_bsr(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Arg2Value);
33623421
% Default case
33633422
op_gc_bif2(
33643423
MMod, MSt0, FailLabel, Live, Bif, _Module, _Function, {typed, Arg1, _}, {typed, Arg2, _}, Dest
@@ -3587,6 +3646,160 @@ op_gc_bif2_bxor(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Rang
35873646
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
35883647
end.
35893648

3649+
% Check if multiplication can be inlined based on type ranges
3650+
% Returns true if the result is guaranteed to fit in a small integer
3651+
can_inline_mul(Range1, Range2, MMod) ->
3652+
{MinSafe, MaxSafe} = small_integer_bounds(MMod),
3653+
case {Range1, Range2} of
3654+
{{Min1, Max1}, {Min2, Max2}} when
3655+
is_integer(Min1),
3656+
is_integer(Max1),
3657+
is_integer(Min2),
3658+
is_integer(Max2)
3659+
->
3660+
% For multiplication, all four corner products must be checked
3661+
Products = [Min1 * Min2, Min1 * Max2, Max1 * Min2, Max1 * Max2],
3662+
MinResult = lists:min(Products),
3663+
MaxResult = lists:max(Products),
3664+
MinResult >= MinSafe andalso MaxResult =< MaxSafe;
3665+
_ ->
3666+
false
3667+
end.
3668+
3669+
% Optimized multiplication with compile-time range checking
3670+
op_gc_bif2_mul(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Range2) when
3671+
is_integer(Arg2)
3672+
->
3673+
case can_inline_mul(Range1, Range2, MMod) of
3674+
true ->
3675+
Arg2Value = Arg2 bsr 4,
3676+
case Arg2Value of
3677+
C when C > 1 ->
3678+
% Strip tag, multiply by constant, re-tag
3679+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Arg1),
3680+
{MSt2, Reg} = MMod:and_(MSt1, {free, Reg}, bnot (?TERM_IMMED_TAG_MASK)),
3681+
MSt3 = MMod:mul(MSt2, Reg, C),
3682+
MSt4 = MMod:or_(MSt3, Reg, ?TERM_INTEGER_TAG),
3683+
MSt5 = MMod:move_to_vm_register(MSt4, Reg, Dest),
3684+
MMod:free_native_registers(MSt5, [Reg, Dest]);
3685+
_ ->
3686+
% 0 or 1 would need special handling (0 produces wrong
3687+
% tag, 1 is identity), and negative constants require
3688+
% sign-aware logic. The compiler typically folds these,
3689+
% but fall back defensively.
3690+
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
3691+
end;
3692+
false ->
3693+
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
3694+
end;
3695+
op_gc_bif2_mul(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, Range2) ->
3696+
case can_inline_mul(Range1, Range2, MMod) of
3697+
true ->
3698+
% Both operands in registers: strip tags, extract value, multiply
3699+
{MSt1, Reg1} = MMod:move_to_native_register(MSt0, Arg1),
3700+
{MSt2, Reg2} = MMod:move_to_native_register(MSt1, Arg2),
3701+
% Strip tag from Reg1: value1 << 4
3702+
{MSt3, Reg1} = MMod:and_(MSt2, {free, Reg1}, bnot (?TERM_IMMED_TAG_MASK)),
3703+
% Strip tag from Reg2 and shift right by 4 to get raw value2
3704+
{MSt4, Reg2} = MMod:and_(MSt3, {free, Reg2}, bnot (?TERM_IMMED_TAG_MASK)),
3705+
{MSt5, Reg2} = MMod:shift_right(MSt4, {free, Reg2}, 4),
3706+
% Multiply: (value1 << 4) * value2 = (value1 * value2) << 4
3707+
MSt6 = MMod:mul(MSt5, Reg1, Reg2),
3708+
% Add tag back
3709+
MSt7 = MMod:or_(MSt6, Reg1, ?TERM_INTEGER_TAG),
3710+
MSt8 = MMod:move_to_vm_register(MSt7, Reg1, Dest),
3711+
MMod:free_native_registers(MSt8, [Reg1, Reg2, Dest]);
3712+
false ->
3713+
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
3714+
end.
3715+
3716+
% Check if left shift can be inlined based on type range and shift amount
3717+
can_inline_bsl(Range1, ShiftAmount, MMod) ->
3718+
{MinSafe, MaxSafe} = small_integer_bounds(MMod),
3719+
case Range1 of
3720+
{Min1, Max1} when
3721+
is_integer(Min1),
3722+
is_integer(Max1),
3723+
ShiftAmount >= 0
3724+
->
3725+
MinResult = Min1 bsl ShiftAmount,
3726+
MaxResult = Max1 bsl ShiftAmount,
3727+
MinResult >= MinSafe andalso MaxResult =< MaxSafe;
3728+
_ ->
3729+
false
3730+
end.
3731+
3732+
% Optimized bsl with compile-time range checking
3733+
op_gc_bif2_bsl(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, ShiftAmount) ->
3734+
case can_inline_bsl(Range1, ShiftAmount, MMod) of
3735+
true ->
3736+
case ShiftAmount of
3737+
0 ->
3738+
% No shift - just copy
3739+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Arg1),
3740+
MSt2 = MMod:move_to_vm_register(MSt1, Reg, Dest),
3741+
MMod:free_native_registers(MSt2, [Reg, Dest]);
3742+
_ ->
3743+
% Strip tag, shift left, re-tag
3744+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Arg1),
3745+
{MSt2, Reg} = MMod:and_(MSt1, {free, Reg}, bnot (?TERM_IMMED_TAG_MASK)),
3746+
MSt3 = MMod:shift_left(MSt2, Reg, ShiftAmount),
3747+
MSt4 = MMod:or_(MSt3, Reg, ?TERM_INTEGER_TAG),
3748+
MSt5 = MMod:move_to_vm_register(MSt4, Reg, Dest),
3749+
MMod:free_native_registers(MSt5, [Reg, Dest])
3750+
end;
3751+
false ->
3752+
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
3753+
end.
3754+
3755+
% Check if right shift can be inlined
3756+
% Only safe for non-negative inputs (the generated native code uses logical
3757+
% shift right, which does not preserve sign for negative values)
3758+
can_inline_bsr(Range1, ShiftAmount, MMod) ->
3759+
{_MinSafe, MaxSafe} = small_integer_bounds(MMod),
3760+
% Ensure (ShiftAmount + 4) does not exceed register width
3761+
% (would be undefined behavior in native shift)
3762+
WordBits = MMod:word_size() * 8,
3763+
case Range1 of
3764+
{Min1, Max1} when
3765+
is_integer(Min1),
3766+
is_integer(Max1),
3767+
Min1 >= 0,
3768+
ShiftAmount >= 0,
3769+
ShiftAmount + 4 < WordBits
3770+
->
3771+
% Non-negative input: right shift can only reduce magnitude
3772+
Max1 =< MaxSafe;
3773+
_ ->
3774+
false
3775+
end.
3776+
3777+
% Optimized bsr with compile-time range checking
3778+
op_gc_bif2_bsr(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest, Range1, ShiftAmount) ->
3779+
case can_inline_bsr(Range1, ShiftAmount, MMod) of
3780+
true ->
3781+
case ShiftAmount of
3782+
0 ->
3783+
% No shift - just copy
3784+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Arg1),
3785+
MSt2 = MMod:move_to_vm_register(MSt1, Reg, Dest),
3786+
MMod:free_native_registers(MSt2, [Reg, Dest]);
3787+
_ ->
3788+
% For non-negative values: shift right by (S+4), shift left by 4, re-tag.
3789+
% This avoids a separate tag-stripping instruction: the combined
3790+
% shift (S+4) removes both the 4 tag bits and applies the S-bit
3791+
% shift in one operation. The tag bits get shifted away since S+4 >= 5.
3792+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Arg1),
3793+
{MSt2, Reg} = MMod:shift_right(MSt1, {free, Reg}, ShiftAmount + 4),
3794+
MSt3 = MMod:shift_left(MSt2, Reg, 4),
3795+
MSt4 = MMod:or_(MSt3, Reg, ?TERM_INTEGER_TAG),
3796+
MSt5 = MMod:move_to_vm_register(MSt4, Reg, Dest),
3797+
MMod:free_native_registers(MSt5, [Reg, Dest])
3798+
end;
3799+
false ->
3800+
op_gc_bif2_default(MMod, MSt0, FailLabel, Live, Bif, Arg1, Arg2, Dest)
3801+
end.
3802+
35903803
% Helper to unwrap typed arguments
35913804
unwrap_typed({typed, Arg, _Type}) -> Arg;
35923805
unwrap_typed(Arg) -> Arg.

libs/jit/src/jit_aarch64.erl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2515,10 +2515,10 @@ sub(#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State
25152515
%% @end
25162516
%% @param State current backend state
25172517
%% @param Reg register to multiply
2518-
%% @param Val constant multiplier (non-negative integer)
2518+
%% @param Val multiplier (an integer constant or a register)
25192519
%% @return Updated backend state
25202520
%%-----------------------------------------------------------------------------
2521-
-spec mul(state(), aarch64_register(), non_neg_integer()) -> state().
2521+
-spec mul(state(), aarch64_register(), integer() | aarch64_register()) -> state().
25222522
mul(State, _Reg, 1) ->
25232523
State;
25242524
mul(State, Reg, 2) ->
@@ -2579,12 +2579,19 @@ mul(
25792579
State,
25802580
Reg,
25812581
Val
2582-
) ->
2582+
) when is_integer(Val) ->
25832583
Temp = first_avail(Avail),
25842584
I1 = jit_aarch64_asm:mov(Temp, Val),
25852585
I2 = jit_aarch64_asm:mul(Reg, Reg, Temp),
25862586
Stream1 = StreamModule:append(Stream0, <<I1/binary, I2/binary>>),
25872587
Regs1 = jit_regs:invalidate_reg(jit_regs:invalidate_reg(Regs0, Temp), Reg),
2588+
State#state{stream = Stream1, regs = Regs1};
2589+
mul(
2590+
#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, DestReg, SrcReg
2591+
) when is_atom(SrcReg) ->
2592+
I1 = jit_aarch64_asm:mul(DestReg, DestReg, SrcReg),
2593+
Stream1 = StreamModule:append(Stream0, I1),
2594+
Regs1 = jit_regs:invalidate_reg(Regs0, DestReg),
25882595
State#state{stream = Stream1, regs = Regs1}.
25892596

25902597
%%-----------------------------------------------------------------------------

libs/jit/src/jit_armv6m.erl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3588,6 +3588,7 @@ sub(#state{stream_module = StreamModule, available_regs = Avail, regs = Regs0} =
35883588
Regs1 = jit_regs:invalidate_reg(jit_regs:invalidate_reg(Regs0, Reg), Temp),
35893589
State1#state{available_regs = Avail, stream = Stream2, regs = Regs1}.
35903590

3591+
-spec mul(state(), armv6m_register(), integer() | armv6m_register()) -> state().
35913592
mul(State, _Reg, 1) ->
35923593
State;
35933594
mul(State, Reg, 2) ->
@@ -3647,7 +3648,7 @@ mul(
36473648
#state{stream_module = StreamModule, available_regs = Avail, regs = Regs0} = State0,
36483649
Reg,
36493650
Val
3650-
) ->
3651+
) when is_integer(Val) ->
36513652
Temp = first_avail(Avail),
36523653
TempBit = reg_bit(Temp),
36533654
AT = Avail band (bnot TempBit),
@@ -3658,7 +3659,14 @@ mul(
36583659
Regs1 = jit_regs:invalidate_reg(jit_regs:invalidate_reg(Regs0, Temp), Reg),
36593660
State1#state{
36603661
stream = Stream2, available_regs = State1#state.available_regs bor TempBit, regs = Regs1
3661-
}.
3662+
};
3663+
mul(
3664+
#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, DestReg, SrcReg
3665+
) when is_atom(SrcReg) ->
3666+
I = jit_armv6m_asm:muls(DestReg, SrcReg),
3667+
Stream1 = StreamModule:append(Stream0, I),
3668+
Regs1 = jit_regs:invalidate_reg(Regs0, DestReg),
3669+
State#state{stream = Stream1, regs = Regs1}.
36623670

36633671
%%
36643672
%% Analysis of AArch64 pattern and ARM Thumb mapping:

libs/jit/src/jit_riscv32.erl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3117,6 +3117,7 @@ sub(#state{stream_module = StreamModule, available_regs = Avail, regs = Regs0} =
31173117
Regs1 = jit_regs:invalidate_reg(jit_regs:invalidate_reg(Regs0, Reg), Temp),
31183118
State1#state{available_regs = Avail, stream = Stream2, regs = Regs1}.
31193119

3120+
-spec mul(state(), riscv32_register(), integer() | riscv32_register()) -> state().
31203121
mul(State, _Reg, 1) ->
31213122
State;
31223123
mul(State, Reg, 2) ->
@@ -3176,7 +3177,7 @@ mul(
31763177
#state{stream_module = StreamModule, available_regs = Avail, regs = Regs0} = State0,
31773178
Reg,
31783179
Val
3179-
) ->
3180+
) when is_integer(Val) ->
31803181
Temp = first_avail(Avail),
31813182
AT = Avail band (bnot reg_bit(Temp)),
31823183
State1 = mov_immediate(State0#state{available_regs = AT}, Temp, Val),
@@ -3188,7 +3189,14 @@ mul(
31883189
stream = Stream2,
31893190
available_regs = State1#state.available_regs bor reg_bit(Temp),
31903191
regs = Regs1
3191-
}.
3192+
};
3193+
mul(
3194+
#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, DestReg, SrcReg
3195+
) when is_atom(SrcReg) ->
3196+
I = jit_riscv32_asm:mul(DestReg, DestReg, SrcReg),
3197+
Stream1 = StreamModule:append(Stream0, I),
3198+
Regs1 = jit_regs:invalidate_reg(Regs0, DestReg),
3199+
State#state{stream = Stream1, regs = Regs1}.
31923200

31933201
%%
31943202
%% RISC-V32 implementation (no prolog/epilog needed due to 32 registers):

libs/jit/src/jit_x86_64.erl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,7 @@ sub(#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State
25642564
Regs1 = jit_regs:invalidate_reg(Regs0, Reg),
25652565
State#state{stream = Stream1, regs = Regs1}.
25662566

2567+
-spec mul(state(), x86_64_register(), integer() | x86_64_register()) -> state().
25672568
mul(State, _Reg, 1) ->
25682569
State;
25692570
mul(State, Reg, 2) ->
@@ -2584,17 +2585,26 @@ mul(
25842585
} = State,
25852586
Reg,
25862587
Val
2587-
) when Val < -16#80000000 orelse Val > 16#7FFFFFFF ->
2588+
) when is_integer(Val), (Val < -16#80000000 orelse Val > 16#7FFFFFFF) ->
25882589
TempReg = first_avail(Avail),
25892590
I1 = jit_x86_64_asm:movabsq(Val, TempReg),
25902591
I2 = jit_x86_64_asm:imulq(TempReg, Reg),
25912592
Stream1 = StreamModule:append(Stream0, <<I1/binary, I2/binary>>),
25922593
Regs1 = jit_regs:invalidate_reg(jit_regs:invalidate_reg(Regs0, TempReg), Reg),
25932594
State#state{stream = Stream1, regs = Regs1};
2594-
mul(#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, Reg, Val) ->
2595+
mul(#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, Reg, Val) when
2596+
is_integer(Val)
2597+
->
25952598
I1 = jit_x86_64_asm:imulq(Val, Reg),
25962599
Stream1 = StreamModule:append(Stream0, I1),
25972600
Regs1 = jit_regs:invalidate_reg(Regs0, Reg),
2601+
State#state{stream = Stream1, regs = Regs1};
2602+
mul(
2603+
#state{stream_module = StreamModule, stream = Stream0, regs = Regs0} = State, DestReg, SrcReg
2604+
) when is_atom(SrcReg) ->
2605+
I1 = jit_x86_64_asm:imulq(SrcReg, DestReg),
2606+
Stream1 = StreamModule:append(Stream0, I1),
2607+
Regs1 = jit_regs:invalidate_reg(Regs0, DestReg),
25982608
State#state{stream = Stream1, regs = Regs1}.
25992609

26002610
%% Signed integer division: quotient = DividendReg / DivisorReg

0 commit comments

Comments
 (0)