Skip to content

Commit d522620

Browse files
committed
Merge pull request #1834 from pguyot/w39/jit-optimize-term_to_int-using-types
JIT: optimize term to int using types These changes are made under both the "Apache 2.0" and the "GNU Lesser General Public License 2.1 or later" license terms (dual license). SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
2 parents 843a3f8 + 1188cea commit d522620

File tree

14 files changed

+591
-21
lines changed

14 files changed

+591
-21
lines changed

libs/estdlib/src/code_server.erl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
code_chunk/1,
3939
atom_resolver/2,
4040
literal_resolver/2,
41+
type_resolver/2,
4142
set_native_code/3
4243
]).
4344

@@ -126,6 +127,14 @@ atom_resolver(_Module, _Index) ->
126127
literal_resolver(_Module, _Index) ->
127128
erlang:nif_error(undefined).
128129

130+
%% @doc Get a type from its index
131+
%% @return The type information
132+
%% @param Module module get a type from
133+
%% @param Index type index in the module
134+
-spec type_resolver(Module :: module(), Index :: non_neg_integer()) -> any().
135+
type_resolver(_Module, _Index) ->
136+
erlang:nif_error(undefined).
137+
129138
%% @doc Associate a native code stream with a module
130139
%% @return ok
131140
%% @param Module module to set the native code of
@@ -154,10 +163,16 @@ load(Module) ->
154163
LiteralResolver = fun(Index) ->
155164
code_server:literal_resolver(Module, Index)
156165
end,
166+
TypeResolver = fun(Index) -> code_server:type_resolver(Module, Index) end,
157167
Stream0 = jit:stream(jit_mmap_size(byte_size(Code))),
158168
{BackendModule, BackendState0} = jit:backend(Stream0),
159169
{LabelsCount, BackendState1} = jit:compile(
160-
Code, AtomResolver, LiteralResolver, BackendModule, BackendState0
170+
Code,
171+
AtomResolver,
172+
LiteralResolver,
173+
TypeResolver,
174+
BackendModule,
175+
BackendState0
161176
),
162177
Stream1 = BackendModule:stream(BackendState1),
163178
code_server:set_native_code(Module, LabelsCount, Stream1),

libs/jit/src/jit.erl

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
stream/1,
2525
backend/1,
2626
beam_chunk_header/3,
27-
compile/5
27+
compile/6
2828
]).
2929

3030
% NIFs
@@ -98,7 +98,8 @@
9898
line_offsets :: [{integer(), integer()}],
9999
labels_count :: pos_integer(),
100100
atom_resolver :: fun((integer()) -> atom()),
101-
literal_resolver :: fun((integer()) -> any())
101+
literal_resolver :: fun((integer()) -> any()),
102+
type_resolver :: fun((integer()) -> any())
102103
}).
103104

104105
-type stream() :: any().
@@ -130,6 +131,7 @@ compile(
130131
<<16:32, 0:32, OpcodeMax:32, LabelsCount:32, _FunctionsCount:32, Opcodes/binary>>,
131132
AtomResolver,
132133
LiteralResolver,
134+
TypeResolver,
133135
MMod,
134136
MSt0
135137
) when OpcodeMax =< ?OPCODE_MAX ->
@@ -138,7 +140,8 @@ compile(
138140
line_offsets = [],
139141
labels_count = LabelsCount,
140142
atom_resolver = AtomResolver,
141-
literal_resolver = LiteralResolver
143+
literal_resolver = LiteralResolver,
144+
type_resolver = TypeResolver
142145
},
143146
{State1, MSt2} = first_pass(Opcodes, MMod, MSt1, State0),
144147
MSt3 = second_pass(MMod, MSt2, State1),
@@ -147,11 +150,12 @@ compile(
147150
<<16:32, 0:32, OpcodeMax:32, _LabelsCount:32, _FunctionsCount:32, _Opcodes/binary>>,
148151
_AtomResolver,
149152
_LiteralResolver,
153+
_TypeResolver,
150154
_MMod,
151155
_MSt
152156
) ->
153157
error(badarg, [OpcodeMax]);
154-
compile(CodeChunk, _AtomResolver, _LiteralResolver, _MMod, _MSt) ->
158+
compile(CodeChunk, _AtomResolver, _LiteralResolver, _TypeResolver, _MMod, _MSt) ->
155159
error(badarg, [CodeChunk]).
156160

157161
% 1
@@ -1143,7 +1147,7 @@ first_pass(<<?OP_IS_FUNCTION2, Rest0/binary>>, MMod, MSt0, State0) ->
11431147
?ASSERT_ALL_NATIVE_FREE(MSt0),
11441148
{Label, Rest1} = decode_label(Rest0),
11451149
{MSt1, Arg1, Rest2} = decode_compact_term(Rest1, MMod, MSt0, State0),
1146-
{MSt2, ArityTerm, Rest3} = decode_compact_term(Rest2, MMod, MSt1, State0),
1150+
{MSt2, ArityTerm, Rest3} = decode_typed_compact_term(Rest2, MMod, MSt1, State0),
11471151
?TRACE("OP_IS_FUNCTION2 ~p,~p,~p\n", [Label, Arg1, ArityTerm]),
11481152
{MSt3, FuncPtr} = term_is_boxed_with_tag_and_get_ptr(Label, Arg1, ?TERM_BOXED_FUN, MMod, MSt2),
11491153
{MSt4, Arity} = term_to_int(ArityTerm, Label, MMod, MSt3),
@@ -1174,7 +1178,7 @@ first_pass(<<?OP_BS_GET_INTEGER2, Rest0/binary>>, MMod, MSt0, State0) ->
11741178
{Fail, Rest1} = decode_label(Rest0),
11751179
{MSt1, Src, Rest2} = decode_compact_term(Rest1, MMod, MSt0, State0),
11761180
{_Live, Rest3} = decode_literal(Rest2),
1177-
{MSt2, Size, Rest4} = decode_compact_term(Rest3, MMod, MSt1, State0),
1181+
{MSt2, Size, Rest4} = decode_typed_compact_term(Rest3, MMod, MSt1, State0),
11781182
{Unit, Rest5} = decode_literal(Rest4),
11791183
{FlagsValue, Rest6} = decode_literal(Rest5),
11801184
{MSt3, SrcReg} = MMod:move_to_native_register(MSt2, Src),
@@ -1213,7 +1217,7 @@ first_pass(<<?OP_BS_GET_FLOAT2, Rest0/binary>>, MMod, MSt0, State0) ->
12131217
{Fail, Rest1} = decode_label(Rest0),
12141218
{MSt1, Src, Rest2} = decode_compact_term(Rest1, MMod, MSt0, State0),
12151219
{_Live, Rest3} = decode_literal(Rest2),
1216-
{MSt2, Size, Rest4} = decode_compact_term(Rest3, MMod, MSt1, State0),
1220+
{MSt2, Size, Rest4} = decode_typed_compact_term(Rest3, MMod, MSt1, State0),
12171221
{Unit, Rest5} = decode_literal(Rest4),
12181222
{FlagsValue, Rest6} = decode_literal(Rest5),
12191223
{MSt3, SrcReg} = MMod:move_to_native_register(MSt2, Src),
@@ -1338,7 +1342,7 @@ first_pass(<<?OP_BS_SKIP_BITS2, Rest0/binary>>, MMod, MSt0, State0) ->
13381342
?ASSERT_ALL_NATIVE_FREE(MSt0),
13391343
{Fail, Rest1} = decode_label(Rest0),
13401344
{MSt1, Src, Rest2} = decode_compact_term(Rest1, MMod, MSt0, State0),
1341-
{MSt2, Size, Rest3} = decode_compact_term(Rest2, MMod, MSt1, State0),
1345+
{MSt2, Size, Rest3} = decode_typed_compact_term(Rest2, MMod, MSt1, State0),
13421346
{Unit, Rest4} = decode_literal(Rest3),
13431347
{_FlagsValue, Rest5} = decode_literal(Rest4),
13441348
?TRACE("OP_BS_SKIP_BITS2 ~p, ~p, ~p, ~p, ~p\n", [Fail, Src, Size, Unit, _FlagsValue]),
@@ -2071,7 +2075,7 @@ first_pass(<<?OP_BS_GET_POSITION, Rest0/binary>>, MMod, MSt0, State0) ->
20712075
first_pass(<<?OP_BS_SET_POSITION, Rest0/binary>>, MMod, MSt0, State0) ->
20722076
?ASSERT_ALL_NATIVE_FREE(MSt0),
20732077
{MSt1, Src, Rest1} = decode_compact_term(Rest0, MMod, MSt0, State0),
2074-
{MSt2, Pos, Rest2} = decode_compact_term(Rest1, MMod, MSt1, State0),
2078+
{MSt2, Pos, Rest2} = decode_typed_compact_term(Rest1, MMod, MSt1, State0),
20752079
?TRACE("OP_BS_SET_POSITION ~p, ~p\n", [Src, Pos]),
20762080
{MSt3, MatchStateReg} = MMod:move_to_native_register(MSt2, Src),
20772081
{MSt4, MatchStateRegPtr} = verify_is_match_state_and_get_ptr(MMod, MSt3, {free, MatchStateReg}),
@@ -2814,7 +2818,7 @@ first_pass_bs_match_integer(
28142818
{_Live, Rest1} = decode_literal(Rest0),
28152819
{Flags, Rest2} = decode_compile_time_literal(Rest1, State0),
28162820
{MSt1, FlagsValue} = decode_flags_list(Flags, MMod, MSt0),
2817-
{MSt2, Size, Rest3} = decode_compact_term(Rest2, MMod, MSt0, State0),
2821+
{MSt2, Size, Rest3} = decode_typed_compact_term(Rest2, MMod, MSt0, State0),
28182822
{Unit, Rest4} = decode_literal(Rest3),
28192823
?TRACE("{integer,~p,~p,~p, ", [Flags, Size, Unit]),
28202824
{MSt3, SizeReg} = term_to_int(Size, 0, MMod, MSt1),
@@ -3185,6 +3189,14 @@ term_to_int(Term, _FailLabel, _MMod, MSt0) when is_integer(Term) ->
31853189
{MSt0, Term bsr 4};
31863190
term_to_int({literal, Val}, _FailLabel, _MMod, MSt0) when is_integer(Val) ->
31873191
{MSt0, Val};
3192+
% Optimized case: when we have type information showing this is an integer, skip the type check
3193+
term_to_int({typed, Term, {t_integer, _Range}}, _FailLabel, MMod, MSt0) ->
3194+
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Term),
3195+
MSt2 = MMod:shift_right(MSt1, Reg, 4),
3196+
{MSt2, Reg};
3197+
term_to_int({typed, Term, _NonIntegerType}, FailLabel, MMod, MSt0) ->
3198+
% Type information shows it's not an integer, fall back to generic path
3199+
term_to_int(Term, FailLabel, MMod, MSt0);
31883200
term_to_int(Term, FailLabel, MMod, MSt0) ->
31893201
{MSt1, Reg} = MMod:move_to_native_register(MSt0, Term),
31903202
MSt2 = cond_raise_badarg_or_jump_to_fail_label(
@@ -3357,6 +3369,17 @@ decode_compact_term(<<_Value:5, ?COMPACT_LITERAL:3, _Rest/binary>> = Binary, _MM
33573369
decode_compact_term(Other, MMod, MSt, _State) ->
33583370
decode_dest(Other, MMod, MSt).
33593371

3372+
% Decode compact term with type information awareness
3373+
decode_typed_compact_term(<<?COMPACT_EXTENDED_TYPED_REGISTER, Rest0/binary>>, MMod, MSt0, #state{
3374+
type_resolver = TypeResover
3375+
}) ->
3376+
{MSt1, Dest, Rest1} = decode_dest(Rest0, MMod, MSt0),
3377+
{TypeIx, Rest2} = decode_literal(Rest1),
3378+
Type = TypeResover(TypeIx),
3379+
{MSt1, {typed, Dest, Type}, Rest2};
3380+
decode_typed_compact_term(Other, MMod, MSt, State) ->
3381+
decode_compact_term(Other, MMod, MSt, State).
3382+
33603383
skip_compact_term(<<_:4, ?COMPACT_INTEGER:4, _Rest/binary>> = Bin) ->
33613384
{_Value, Rest} = decode_value64(Bin),
33623385
Rest;

libs/jit/src/jit_precompile.erl

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
%
2020
-module(jit_precompile).
2121

22-
-export([start/0, compile/3]).
22+
-export([start/0, compile/3, atom_resolver/1, type_resolver/1]).
2323

2424
-include_lib("jit.hrl").
2525

@@ -36,8 +36,7 @@ compile(Target, Dir, Path) ->
3636
FilteredChunks = lists:keydelete("Code", 1, FilteredChunks0),
3737
{"Code", CodeChunk} = lists:keyfind("Code", 1, InitialChunks),
3838
{"AtU8", AtomChunk} = lists:keyfind("AtU8", 1, InitialChunks),
39-
Atoms = parse_atom_chunk(AtomChunk),
40-
AtomResolver = fun(Index) -> lists:nth(Index, Atoms) end,
39+
AtomResolver = atom_resolver(AtomChunk),
4140
LiteralsChunk =
4241
case lists:keyfind("LitU", 1, InitialChunks) of
4342
{"LitU", LiteralsChunk0} ->
@@ -52,8 +51,16 @@ compile(Target, Dir, Path) ->
5251
<<>>
5352
end
5453
end,
55-
Literals = parse_literals_chunk(LiteralsChunk),
56-
LiteralResolver = fun(Index) -> lists:nth(Index + 1, Literals) end,
54+
LiteralResolver = literal_resolver(LiteralsChunk),
55+
56+
TypesChunk =
57+
case lists:keyfind("Type", 1, InitialChunks) of
58+
{"Type", TypesChunk0} ->
59+
TypesChunk0;
60+
false ->
61+
<<>>
62+
end,
63+
TypeResolver = type_resolver(TypesChunk),
5764

5865
Stream0 = jit_stream_binary:new(0),
5966
<<16:32, 0:32, _OpcodeMax:32, LabelsCount:32, _FunctionsCount:32, _Opcodes/binary>> =
@@ -64,7 +71,7 @@ compile(Target, Dir, Path) ->
6471
Backend = list_to_atom("jit_" ++ Target),
6572
Stream2 = Backend:new(?JIT_VARIANT_PIC, jit_stream_binary, Stream1),
6673
{LabelsCount, Stream3} = jit:compile(
67-
CodeChunk, AtomResolver, LiteralResolver, Backend, Stream2
74+
CodeChunk, AtomResolver, LiteralResolver, TypeResolver, Backend, Stream2
6875
),
6976
NativeCode = Backend:stream(Stream3),
7077
UpdatedChunks = FilteredChunks ++ [{"avmN", NativeCode}],
@@ -78,6 +85,10 @@ compile(Target, Dir, Path) ->
7885
io:format("Unimplemented opcode ~p (~s)\n", [Opcode, Path])
7986
end.
8087

88+
atom_resolver(AtomChunk) ->
89+
Atoms = parse_atom_chunk(AtomChunk),
90+
fun(Index) -> lists:nth(Index, Atoms) end.
91+
8192
parse_atom_chunk(<<AtomCount:32/signed, Rest/binary>>) ->
8293
if
8394
AtomCount < 0 ->
@@ -100,6 +111,10 @@ parse_atom_chunk_old_format(<<Size, Atom:Size/binary, Tail/binary>>, Acc) ->
100111
parse_atom_chunk_old_format(<<>>, Acc) ->
101112
lists:reverse(Acc).
102113

114+
literal_resolver(LiteralsChunk) ->
115+
Literals = parse_literals_chunk(LiteralsChunk),
116+
fun(Index) -> lists:nth(Index + 1, Literals) end.
117+
103118
parse_literals_chunk(<<TermsCount:32, Rest/binary>>) ->
104119
parse_literals_chunk0(TermsCount, Rest, []);
105120
parse_literals_chunk(<<>>) ->
@@ -110,3 +125,82 @@ parse_literals_chunk0(0, <<>>, Acc) ->
110125
parse_literals_chunk0(N, <<TermSize:32, TermBin:TermSize/binary, Rest/binary>>, Acc) ->
111126
Term = binary_to_term(TermBin),
112127
parse_literals_chunk0(N - 1, Rest, [Term | Acc]).
128+
129+
%% Version (from beam_types.hrl)
130+
-define(BEAM_TYPES_VERSION, 3).
131+
132+
%% Type chunk constants (from beam_types.erl)
133+
-define(BEAM_TYPE_ATOM, (1 bsl 0)).
134+
-define(BEAM_TYPE_BITSTRING, (1 bsl 1)).
135+
-define(BEAM_TYPE_CONS, (1 bsl 2)).
136+
-define(BEAM_TYPE_FLOAT, (1 bsl 3)).
137+
-define(BEAM_TYPE_FUN, (1 bsl 4)).
138+
-define(BEAM_TYPE_INTEGER, (1 bsl 5)).
139+
-define(BEAM_TYPE_MAP, (1 bsl 6)).
140+
-define(BEAM_TYPE_NIL, (1 bsl 7)).
141+
-define(BEAM_TYPE_PID, (1 bsl 8)).
142+
-define(BEAM_TYPE_PORT, (1 bsl 9)).
143+
-define(BEAM_TYPE_REFERENCE, (1 bsl 10)).
144+
-define(BEAM_TYPE_TUPLE, (1 bsl 11)).
145+
146+
-define(BEAM_TYPE_HAS_LOWER_BOUND, (1 bsl 12)).
147+
-define(BEAM_TYPE_HAS_UPPER_BOUND, (1 bsl 13)).
148+
-define(BEAM_TYPE_HAS_UNIT, (1 bsl 14)).
149+
150+
type_resolver(<<Version:32, _Count:32, TypeData/binary>>) when Version =:= ?BEAM_TYPES_VERSION ->
151+
Types = parse_type_entries(TypeData, []),
152+
fun(Index) -> lists:nth(Index + 1, Types) end;
153+
type_resolver(_) ->
154+
fun(_) -> any end.
155+
156+
parse_type_entries(<<>>, Acc) ->
157+
lists:reverse(Acc);
158+
parse_type_entries(
159+
<<0:1, HasUnit:1, HasUpperBound:1, HasLowerBound:1, TypeBits:12, Rest0/binary>>, Acc
160+
) ->
161+
{Rest, LowerBound, UpperBound, Unit} = parse_extra(
162+
HasLowerBound, HasUpperBound, HasUnit, Rest0, '-inf', '+inf', 1
163+
),
164+
Type =
165+
case TypeBits of
166+
?BEAM_TYPE_ATOM ->
167+
t_atom;
168+
?BEAM_TYPE_BITSTRING ->
169+
{t_bs_matchable, Unit};
170+
?BEAM_TYPE_CONS ->
171+
t_cons;
172+
?BEAM_TYPE_FLOAT ->
173+
t_float;
174+
?BEAM_TYPE_FUN ->
175+
t_fun;
176+
?BEAM_TYPE_FLOAT bor ?BEAM_TYPE_INTEGER ->
177+
{t_number, {LowerBound, UpperBound}};
178+
?BEAM_TYPE_INTEGER ->
179+
{t_integer, {LowerBound, UpperBound}};
180+
?BEAM_TYPE_MAP ->
181+
t_map;
182+
?BEAM_TYPE_NIL ->
183+
nil;
184+
?BEAM_TYPE_NIL bor ?BEAM_TYPE_CONS ->
185+
t_list;
186+
?BEAM_TYPE_PID ->
187+
pid;
188+
?BEAM_TYPE_PORT ->
189+
port;
190+
?BEAM_TYPE_REFERENCE ->
191+
reference;
192+
?BEAM_TYPE_TUPLE ->
193+
t_tuple;
194+
_ ->
195+
any
196+
end,
197+
parse_type_entries(Rest, [Type | Acc]).
198+
199+
parse_extra(1, HasUpperBound, HasUnit, <<Value:64/signed, Rest/binary>>, '-inf', '+inf', 1) ->
200+
parse_extra(0, HasUpperBound, HasUnit, Rest, Value, '+inf', 1);
201+
parse_extra(0, 1, HasUnit, <<Value:64/signed, Rest/binary>>, LowerBound, '+inf', 1) ->
202+
parse_extra(0, 0, HasUnit, Rest, LowerBound, Value, 1);
203+
parse_extra(0, 0, 1, <<Value:8/unsigned, Rest/binary>>, LowerBound, UpperBound, 1) ->
204+
parse_extra(0, 0, 0, Rest, LowerBound, UpperBound, Value + 1);
205+
parse_extra(0, 0, 0, Rest, LowerBound, UpperBound, Unit) ->
206+
{Rest, LowerBound, UpperBound, Unit}.

src/libAtomVM/iff.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ void scan_iff(const void *iff_binary, int buf_size, unsigned long *offsets, unsi
100100
} else if (!memcmp(current_record->name, "avmN", 4)) {
101101
offsets[AVMN] = current_pos;
102102
sizes[AVMN] = ENDIAN_SWAP_32(current_record->size);
103+
} else if (!memcmp(current_record->name, "Type", 4)) {
104+
offsets[TYPE] = current_pos;
105+
sizes[TYPE] = ENDIAN_SWAP_32(current_record->size);
103106
}
104107

105108
current_pos += iff_align(ENDIAN_SWAP_32(current_record->size) + 8);

src/libAtomVM/iff.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@ extern "C" {
5858
#define LINT 9
5959
/** Native code section */
6060
#define AVMN 10
61+
/** Type table section */
62+
#define TYPE 11
6163

6264
/** Required size for offsets array */
63-
#define MAX_OFFS 11
65+
#define MAX_OFFS 12
6466
/** Required size for sizes array */
65-
#define MAX_SIZES 11
67+
#define MAX_SIZES 12
6668

6769
/** sizeof IFF section header in bytes */
6870
#define IFF_SECTION_HEADER_SIZE 8

0 commit comments

Comments
 (0)