Skip to content

Commit e8a3ed2

Browse files
authored
feat(compiler): Deduplicate foreign imports (#2233)
1 parent a76df88 commit e8a3ed2

File tree

4 files changed

+113
-40
lines changed

4 files changed

+113
-40
lines changed

compiler/src/codegen/compcore.re

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ open Grain_utils;
88
open Comp_utils;
99
open Comp_wasm_prim;
1010

11+
module StringSet = Set.Make(String);
12+
1113
let sources: ref(list((Expression.t, Grain_parsing.Location.t))) = ref([]);
1214

1315
/** Environment */
@@ -22,6 +24,7 @@ type codegen_env = {
2224
/* Allocated closures which need backpatching */
2325
backpatches: ref(list((Expression.t, closure_data))),
2426
required_imports: list(import),
27+
foreign_import_resolutions: ref(StringSet.t),
2528
global_import_resolutions: Hashtbl.t(string, string),
2629
func_import_resolutions: Hashtbl.t(string, string),
2730
compilation_mode: Config.compilation_mode,
@@ -90,6 +93,7 @@ let init_codegen_env =
9093
},
9194
backpatches: ref([]),
9295
required_imports: [],
96+
foreign_import_resolutions: ref(StringSet.empty),
9397
global_import_resolutions,
9498
func_import_resolutions,
9599
compilation_mode: Normal,
@@ -2782,6 +2786,14 @@ and compile_instr = (wasm_mod, env, instr) =>
27822786
compiled_args,
27832787
Type.create(Array.of_list(List.map(wasm_type, retty))),
27842788
);
2789+
} else if (StringSet.mem(func_name, env.foreign_import_resolutions^)) {
2790+
// Deduplicated imports; call resolved name directly
2791+
Expression.Call.make(
2792+
wasm_mod,
2793+
resolved_name,
2794+
compiled_args,
2795+
Type.create(Array.of_list(List.map(wasm_type, retty))),
2796+
);
27852797
} else {
27862798
// Raw function resolved to Grain function; inject closure argument
27872799
let closure_global = resolve_global(~env, func_name);
@@ -3070,7 +3082,7 @@ let compute_table_size = (env, {function_table_elements}) => {
30703082
List.length(function_table_elements);
30713083
};
30723084

3073-
let compile_imports = (wasm_mod, env, {imports}) => {
3085+
let compile_imports = (wasm_mod, env, {imports}, import_map) => {
30743086
let compile_module_name = name =>
30753087
fun
30763088
| MImportWasm => name
@@ -3081,7 +3093,6 @@ let compile_imports = (wasm_mod, env, {imports}) => {
30813093
| (MImportGrain, MGlobalImport(_)) => "GRAIN$EXPORT$" ++ name
30823094
| _ => name
30833095
};
3084-
30853096
let compile_import = ({mimp_id, mimp_mod, mimp_name, mimp_type, mimp_kind}) => {
30863097
let module_name = compile_module_name(mimp_mod, mimp_kind);
30873098
let item_name = compile_import_name(mimp_name, mimp_kind, mimp_type);
@@ -3090,37 +3101,53 @@ let compile_imports = (wasm_mod, env, {imports}) => {
30903101
| MImportGrain => get_grain_imported_name(mimp_mod, mimp_id)
30913102
| MImportWasm => Ident.unique_name(mimp_id)
30923103
};
3093-
switch (mimp_kind, mimp_type) {
3094-
| (MImportGrain, MGlobalImport(ty, mut)) =>
3095-
Import.add_global_import(
3096-
wasm_mod,
3097-
internal_name,
3098-
module_name,
3099-
item_name,
3100-
wasm_type(ty),
3101-
mut,
3102-
)
3103-
| (_, MFuncImport(args, ret)) =>
3104-
let proc_list = l =>
3105-
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
3106-
Import.add_function_import(
3107-
wasm_mod,
3108-
internal_name,
3109-
module_name,
3110-
item_name,
3111-
proc_list(args),
3112-
proc_list(ret),
3113-
);
3114-
| (_, MGlobalImport(typ, mut)) =>
3115-
let typ = wasm_type(typ);
3116-
Import.add_global_import(
3117-
wasm_mod,
3118-
internal_name,
3119-
module_name,
3120-
item_name,
3121-
typ,
3122-
mut,
3123-
);
3104+
let import_key = (module_name, item_name, mimp_kind, mimp_type);
3105+
switch (Hashtbl.find_opt(import_map, import_key)) {
3106+
| Some(name) when mimp_kind == MImportWasm =>
3107+
// Deduplicate wasm imports by resolving them to the previously imported name
3108+
let linked_name = linked_name(~env, internal_name);
3109+
switch (mimp_type) {
3110+
| MFuncImport(_, _) =>
3111+
Hashtbl.add(env.func_import_resolutions, linked_name, name)
3112+
| MGlobalImport(_, _) =>
3113+
Hashtbl.add(env.global_import_resolutions, linked_name, name)
3114+
};
3115+
env.foreign_import_resolutions :=
3116+
StringSet.add(linked_name, env.foreign_import_resolutions^);
3117+
| _ =>
3118+
Hashtbl.add(import_map, import_key, internal_name);
3119+
switch (mimp_kind, mimp_type) {
3120+
| (MImportGrain, MGlobalImport(ty, mut)) =>
3121+
Import.add_global_import(
3122+
wasm_mod,
3123+
internal_name,
3124+
module_name,
3125+
item_name,
3126+
wasm_type(ty),
3127+
mut,
3128+
)
3129+
| (_, MFuncImport(args, ret)) =>
3130+
let proc_list = l =>
3131+
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
3132+
Import.add_function_import(
3133+
wasm_mod,
3134+
internal_name,
3135+
module_name,
3136+
item_name,
3137+
proc_list(args),
3138+
proc_list(ret),
3139+
);
3140+
| (_, MGlobalImport(typ, mut)) =>
3141+
let typ = wasm_type(typ);
3142+
Import.add_global_import(
3143+
wasm_mod,
3144+
internal_name,
3145+
module_name,
3146+
item_name,
3147+
typ,
3148+
mut,
3149+
);
3150+
};
31243151
};
31253152
};
31263153

@@ -3472,12 +3499,14 @@ let compile_wasm_module =
34723499
Type.funcref,
34733500
);
34743501

3502+
let import_map = Hashtbl.create(10);
3503+
34753504
let compile_one = (dep_id, prog: mash_code) => {
34763505
let env = {...env, dep_id, compilation_mode: prog.compilation_mode};
3506+
ignore @@ compile_imports(wasm_mod, env, prog, import_map);
34773507
ignore @@ compile_globals(wasm_mod, env, prog);
34783508
ignore @@ compile_functions(wasm_mod, env, prog);
34793509
ignore @@ compile_exports(wasm_mod, env, prog);
3480-
ignore @@ compile_imports(wasm_mod, env, prog);
34813510
ignore @@ compile_tables(wasm_mod, env, prog);
34823511
};
34833512

compiler/src/codegen/linkedtree.re

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,7 @@ let link = main_mashtree => {
109109
(resolved_module, import.mimp_name),
110110
);
111111
let import_name =
112-
Printf.sprintf(
113-
"%s_%d",
114-
Ident.unique_name(import.mimp_id),
115-
dep_id^,
116-
);
112+
internal_name(Ident.unique_name(import.mimp_id), dep_id^);
117113
Option.iter(
118114
global =>
119115
Hashtbl.add(global_import_resolutions, import_name, global),

compiler/test/suites/basic_functionality.re

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => {
377377
~config_fn=smallestFileConfig,
378378
"smallest_grain_program",
379379
"",
380-
6507,
380+
6540,
381381
);
382382
});

compiler/test/suites/includes.re

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,52 @@ describe("includes", ({test, testSkip}) => {
207207
"from \"reprovideContents\" include ReprovideContents; use ReprovideContents.{ type OtherT as Other }; print({ x: 1 }: Other)",
208208
"{\n x: 1\n}\n",
209209
);
210+
/* Duplicate imports */
211+
test("dedupe_includes", ({expect}) => {
212+
let name = "dedupe_includes";
213+
let outfile = wasmfile(name);
214+
ignore @@
215+
compile(
216+
~hook=Grain.Compile.stop_after_assembled,
217+
name,
218+
{|
219+
module DeDupeIncludes
220+
// Ensures test is only included once
221+
foreign wasm test: WasmI32 => WasmI32 from "env"
222+
let test2 = test
223+
foreign wasm test: WasmI32 => WasmI32 from "env"
224+
@unsafe
225+
let _ = {
226+
test(1n)
227+
test2(1n)
228+
}
229+
|},
230+
);
231+
let ic = open_in_bin(outfile);
232+
let sections = Grain_utils.Wasm_utils.get_wasm_sections(ic);
233+
close_in(ic);
234+
let import_section =
235+
List.find_map(
236+
(sec: Grain_utils.Wasm_utils.wasm_bin_section) =>
237+
switch (sec) {
238+
| {sec_type: Import(imports)} => Some(imports)
239+
| _ => None
240+
},
241+
sections,
242+
);
243+
expect.option(import_section).toBeSome();
244+
expect.int(List.length(Option.get(import_section))).toBe(2);
245+
// Runtime printing import
246+
expect.list(Option.get(import_section)).toContainEqual((
247+
WasmFunction,
248+
"wasi_snapshot_preview1",
249+
"fd_write",
250+
));
251+
// Test import
252+
expect.list(Option.get(import_section)).toContainEqual((
253+
WasmFunction,
254+
"env",
255+
"test",
256+
));
257+
});
210258
});

0 commit comments

Comments
 (0)