Skip to content

Commit f18220e

Browse files
committed
Experimental WebAssembly multithreading support
1 parent 2b9a482 commit f18220e

File tree

17 files changed

+547
-29
lines changed

17 files changed

+547
-29
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/aoc_wasm/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ rust-version = { workspace = true }
99

1010
[dependencies]
1111
aoc = { path = "../aoc" }
12+
utils = { path = "../utils", optional = true }
13+
14+
[features]
15+
multithreading = ["utils/wasm-multithreading"]
1216

1317
[lints]
1418
workspace = true

crates/aoc_wasm/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//! Simple WebAssembly interface without external libraries.
22
33
mod custom_sections;
4+
#[cfg(feature = "multithreading")]
5+
mod multithreading;
46

57
use aoc::all_puzzles;
68
use aoc::utils::input::InputType;
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use aoc::utils::wasm::scoped_tasks::worker;
2+
use std::alloc::{alloc_zeroed, Layout};
3+
4+
/// Allocate a zero-initialised stack for a worker thread.
///
/// Exported with the C ABI under an unmangled name so the host can call it. Returns a pointer to
/// `size` zeroed bytes aligned to `align`.
///
/// **WARNING**: Stack overflows on worker threads will corrupt other parts of the linear memory.
///
/// # Panics
///
/// Panics if `size` is zero, or if `size` and `align` do not form a valid `Layout` (for example
/// when `align` is not a power of two). Allocation failure is reported through
/// `std::alloc::handle_alloc_error`, so a null pointer is never returned to the caller.
#[no_mangle]
extern "C" fn allocate_stack(size: usize, align: usize) -> *mut u8 {
    // Calling the global allocator with a zero-sized layout is undefined behaviour, so reject it
    // before reaching the `unsafe` block.
    assert!(size > 0, "worker stack size must be non-zero");
    let layout = Layout::from_size_align(size, align)
        .expect("worker stack size and alignment must form a valid layout");

    // SAFETY: `layout` has a non-zero size (asserted above), which is the only precondition of
    // `alloc_zeroed`.
    let stack = unsafe { alloc_zeroed(layout) };
    if stack.is_null() {
        // The caller offsets the returned address without checking it, so fail loudly here
        // rather than handing back a null base address.
        std::alloc::handle_alloc_error(layout);
    }
    stack
}
12+
13+
/// Run worker thread.
///
/// Exported with the C ABI under an unmangled name; the body only delegates to
/// `aoc::utils::wasm::scoped_tasks::worker`.
///
/// The `"thread"` handler in `worker.mjs` sets `__stack_pointer` and calls `__wasm_init_tls`
/// before invoking this export, and throws "unreachable" if the call returns.
/// NOTE(review): this suggests `worker` is expected never to return — confirm against its definition.
14+
#[no_mangle]
15+
extern "C" fn worker_thread() {
16+
worker();
17+
}

crates/aoc_wasm/web/aoc.mjs

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
* @property {WebAssembly.Global} PART1
77
* @property {WebAssembly.Global} PART2
88
* @property {number} PUZZLES
9+
* @property {WebAssembly.Global} [__tls_size] If multithreaded
10+
* @property {WebAssembly.Global} [__tls_align] If multithreaded
11+
* @property {WebAssembly.Global} [__tls_base] If multithreaded
12+
* @property {WebAssembly.Global} [__stack_pointer] If multithreaded
13+
* @property {(size: number, align: number) => number} [allocate_stack] If multithreaded
14+
* @property {() => void} [worker_thread] If multithreaded
915
*/
1016

1117
/**
@@ -22,12 +28,16 @@
2228
const BUFFER_SIZE = 1024 * 1024;
2329

2430
export class Aoc {
31+
/** @type {boolean} */
32+
#multithreaded;
2533
/** @type {WebAssembly.Module} */
2634
#module;
2735
/** @type {WebAssembly.Instance} */
2836
#instance;
29-
/** @type {Puzzles} */
30-
#puzzles;
37+
/** @type {WebAssembly.Memory} */
38+
#memory;
39+
/** @type {Worker[]} */
40+
#workers;
3141

3242
/**
3343
* @param {WebAssembly.Module} module
@@ -90,14 +100,66 @@ export class Aoc {
90100
* @param {WebAssembly.Instance} [instance]
91101
*/
92102
constructor(module, instance) {
93-
this.#module = module;
94-
this.#instance = instance ?? new WebAssembly.Instance(module);
103+
const imports = WebAssembly.Module.imports(module);
104+
if (imports.length === 0) {
105+
this.#multithreaded = false;
106+
this.#module = module;
107+
this.#instance = instance ?? new WebAssembly.Instance(module);
108+
this.#memory = this.#exports.memory;
109+
} else if (imports.length === 1 && imports[0].module === "env" && imports[0].name === "memory" && imports[0].kind === "memory") {
110+
this.#multithreaded = true;
111+
this.#module = module;
112+
if (instance) throw new Error("Instance cannot be provided for multithreaded modules");
113+
this.newInstance();
114+
} else {
115+
throw new Error("Unsupported module");
116+
}
95117
}
96118

97-
/** @return {Puzzles} */
98-
get puzzles() {
99-
this.#puzzles ??= Aoc.puzzleList(this.#module);
100-
return this.#puzzles;
119+
/** @param {number} [numWorkers] */
120+
newInstance(numWorkers) {
121+
if (this.#multithreaded) {
122+
if (this.#workers?.length > 0) {
123+
// Stop existing workers
124+
for (const worker of this.#workers) {
125+
worker.terminate();
126+
}
127+
numWorkers ??= this.#workers.length;
128+
this.#workers = [];
129+
}
130+
numWorkers ??= navigator.hardwareConcurrency;
131+
132+
this.#memory = new WebAssembly.Memory({initial: 96, maximum: 2048, shared: true});
133+
this.#instance = new WebAssembly.Instance(this.#module, {env: {memory: this.#memory}});
134+
135+
// Stack alignment must be at least 16 bytes.
136+
//
137+
// Only aligning the stack to 8 bytes (this.#exports.__tls_align.value at the time of writing) causes 2016
138+
// day 14 to inconsistently return wrong answers in release builds as the optimizer uses `i32.or` instead of
139+
// `i32.add` when adding on small array indexes.
140+
let align = Math.max(16, this.#exports.__tls_align.value);
141+
let tlsSize = Math.ceil(this.#exports.__tls_size.value / align) * align;
142+
let stackSize = Math.ceil(this.#exports.__stack_pointer.value / align) * align;
143+
144+
// Use a single allocation for stack & tls, using the first stackSize bytes for the stack and the remaining
145+
// tlsSize bytes for thread local storage. This makes __tls_base and __stack_pointer the same value (as
146+
// the stack grows downwards and TLS is above __tls_base), similar to the main thread.
147+
//
148+
// Allocate all the stacks at once to avoid memory growing as workers start, which seems to cause problems.
149+
const stacks = [];
150+
for (let i = 0; i < numWorkers; i++) {
151+
stacks.push(this.#exports.allocate_stack(stackSize + tlsSize, align));
152+
}
153+
154+
this.#workers = [];
155+
for (let i = 0; i < numWorkers; i++) {
156+
const worker = new Worker("./worker.mjs", {type: "module"});
157+
worker.postMessage(["thread", this.#module, this.#memory, stacks[i] + stackSize]);
158+
this.#workers.push(worker);
159+
}
160+
} else {
161+
this.#instance = new WebAssembly.Instance(this.#module);
162+
}
101163
}
102164

103165
/**
@@ -115,7 +177,7 @@ export class Aoc {
115177
this.#write(input);
116178
success = this.#exports.run_puzzle(year, day, isExample, part1, part2);
117179
} catch (e) {
118-
this.#instance = new WebAssembly.Instance(this.#module);
180+
this.newInstance();
119181
return {
120182
success: false,
121183
error: "Unexpected error: " + e.toString() + (e.stack ? "\n\n" + e.stack : ""),
@@ -147,18 +209,29 @@ export class Aoc {
147209
*/
148210
#buffer(type) {
149211
const address = this.#exports[type].value;
150-
return new Uint8Array(this.#exports.memory.buffer)
212+
return new Uint8Array(this.#memory.buffer)
151213
.subarray(address, address + BUFFER_SIZE);
152214
}
153215

154216
/** @param {string} input */
155217
#write(input) {
156218
const buffer = this.#buffer("INPUT");
157-
const result = new TextEncoder().encodeInto(input, buffer);
158-
if (result.read < input.length || result.written === buffer.length) {
159-
throw new Error("Input string is too long");
219+
if (this.#multithreaded) {
220+
// Can't encode directly into SharedArrayBuffer
221+
const temp = new Uint8Array(BUFFER_SIZE);
222+
const result = new TextEncoder().encodeInto(input, temp);
223+
if (result.read < input.length || result.written === buffer.length) {
224+
throw new Error("Input string is too long");
225+
}
226+
buffer.set(temp.subarray(0, result.written));
227+
buffer[result.written] = 0;
228+
} else {
229+
const result = new TextEncoder().encodeInto(input, buffer);
230+
if (result.read < input.length || result.written === buffer.length) {
231+
throw new Error("Input string is too long");
232+
}
233+
buffer[result.written] = 0;
160234
}
161-
buffer[result.written] = 0;
162235
}
163236

164237
/**
@@ -173,6 +246,13 @@ export class Aoc {
173246
buffer = buffer.subarray(0, end);
174247
}
175248

249+
if (this.#multithreaded) {
250+
// Can't decode directly from SharedArrayBuffer
251+
const temp = new Uint8Array(buffer.length);
252+
temp.set(buffer);
253+
buffer = temp;
254+
}
255+
176256
return (new TextDecoder()).decode(buffer);
177257
}
178258
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// This is a hack to set Cross-Origin-Embedder-Policy/Cross-Origin-Opener-Policy (required to send SharedArrayBuffer
// instances across workers/use multithreaded WebAssembly) on GitHub pages
//
// The same file runs in two contexts: as a classic script in the page, where it registers itself as a service
// worker, and as that service worker, where it re-issues every fetch and adds the two headers to the response.

if (typeof window === "undefined") {
    // Service worker context: take control as soon as possible so the next page load is served with the headers.
    self.addEventListener("install", () => self.skipWaiting());
    self.addEventListener("activate", event => event.waitUntil(self.clients.claim()));

    self.addEventListener("fetch", event => {
        // Re-issuing a request whose cache mode is "only-if-cached" is only permitted when its mode is
        // "same-origin"; otherwise fetch() rejects with a TypeError. Leave such requests to the browser.
        if (event.request.cache === "only-if-cached" && event.request.mode !== "same-origin") return;

        event.respondWith(fetch(event.request).then(response => {
            // Opaque responses report status 0 and cannot be re-wrapped, so pass them through untouched.
            if (response.status === 0) return response;

            const headers = new Headers(response.headers);
            headers.set('Cross-Origin-Embedder-Policy', 'require-corp');
            headers.set('Cross-Origin-Opener-Policy', 'same-origin');

            return new Response(response.body, {status: response.status, statusText: response.statusText, headers});
        }));
    });
} else if (window.crossOriginIsolated && !navigator.serviceWorker?.controller) {
    console.log("cross-origin isolated without service worker"); // No service worker workaround needed
} else if (!navigator.serviceWorker) {
    // Service workers are unavailable here (e.g. an insecure context), so the workaround cannot be applied.
    console.error("service workers unavailable, cannot enable cross-origin isolation");
} else {
    if (window.crossOriginIsolated) {
        console.log("cross-origin isolated using service worker"); // Still re-register
    } else {
        console.log("not cross-origin isolated, trying service worker");
    }

    navigator.serviceWorker.register(document.currentScript.src).then(registration => {
        console.log(`service worker registered with scope ${registration.scope}`);

        // A new or updated worker only affects responses for subsequent loads, so reload to pick up the headers.
        registration.addEventListener("updatefound", () => {
            console.log("reloading page due to service worker update");
            window.location.reload();
        });
    }, error => {
        console.error("error registering service worker", error);
    });
}

crates/aoc_wasm/web/index.html

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@
6868
}
6969
}
7070
</style>
71+
<script src="cross-origin-isolation-service-worker.js" defer></script>
7172
<script src="./web.mjs" type="module"></script>
7273
<link rel="modulepreload" href="./aoc.mjs" />
7374
<link rel="modulepreload" href="./worker.mjs" />
74-
<link rel="preload" href="./aoc-simd128.wasm" as="fetch" crossorigin />
75+
<link rel="preload" href="./aoc-threads.wasm" as="fetch" crossorigin />
7576
</head>
7677
<body class="is-flex is-flex-direction-column">
7778
<nav class="navbar" role="navigation" aria-label="main navigation">

crates/aoc_wasm/web/web.mjs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ const MODULE_PATHS = [
44
"./aoc-simd128.wasm",
55
"./aoc.wasm",
66
];
7+
if (window.crossOriginIsolated) {
8+
MODULE_PATHS.unshift("./aoc-threads.wasm");
9+
}
710

811
let module;
912
for (const path of MODULE_PATHS) {

crates/aoc_wasm/web/worker.mjs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,12 @@ onmessage = (e) => {
1515
postMessage(result);
1616
console.log(result);
1717
break;
18+
case "thread":
19+
const [module, memory, ptr] = e.data;
20+
instance = new WebAssembly.Instance(module, {env: {memory}});
21+
instance.exports.__stack_pointer.value = ptr; // Stack uses storage below the provided pointer
22+
instance.exports.__wasm_init_tls(ptr); // TLS uses storage above the provided pointer
23+
instance.exports.worker_thread();
24+
throw new Error("unreachable");
1825
}
1926
};

crates/utils/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ rust-version = { workspace = true }
1111

1212
[features]
1313
unsafe = []
14+
wasm-multithreading = ["unsafe"]
1415

1516
[lints]
1617
workspace = true

0 commit comments

Comments
 (0)