Skip to content

Commit 21303b6

Browse files
authored
mnist : use CMake to build mnist wasm example (#1269)
This commit updates the mnist examples to use CMake for building the WebAssembly (WASM) version of the MNIST example instead of the current emcc command. The motivation for this change is that using CMake should make it easier to maintin with regards to when changes in ggml occur they should not cause this example to break. Currently the emcc command is outdated and it was not clear how to updated it which is why this change was made. Resolves: #1264
1 parent 6a7d170 commit 21303b6

File tree

4 files changed

+91
-7
lines changed

4 files changed

+91
-7
lines changed

examples/mnist/CMakeLists.txt

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,41 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)
1818
set(TEST_TARGET mnist-train)
1919
add_executable(${TEST_TARGET} mnist-train.cpp)
2020
target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)
21+
22+
23+
#
24+
# mnist-wasm
25+
if (EMSCRIPTEN)
26+
set(TARGET mnist)
27+
28+
add_executable(${TARGET} mnist-common.cpp)
29+
target_link_libraries(${TARGET} PRIVATE ggml ggml-cpu)
30+
31+
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
32+
--bind \
33+
-s FORCE_FILESYSTEM=1 \
34+
-s USE_PTHREADS=1 \
35+
-s PTHREAD_POOL_SIZE=10 \
36+
-s ASSERTIONS=1 \
37+
-s WASM=1 \
38+
-s EXPORTED_RUNTIME_METHODS=\"['ccall', 'cwrap', 'setValue', 'getValue']\" \
39+
-s EXPORTED_FUNCTIONS=\"['_wasm_eval','_wasm_random_digit','_malloc','_free']\" \
40+
-s ALLOW_MEMORY_GROWTH=1 \
41+
--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/mnist-f32.gguf@/ \
42+
--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/t10k-images-idx3-ubyte@/ \
43+
")
44+
45+
# Copy output to web directory
46+
add_custom_command(
47+
TARGET ${TARGET} POST_BUILD
48+
COMMAND ${CMAKE_COMMAND} -E copy
49+
${CMAKE_BINARY_DIR}/bin/mnist.js
50+
${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.js
51+
COMMAND ${CMAKE_COMMAND} -E copy
52+
${CMAKE_BINARY_DIR}/bin/mnist.wasm
53+
${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.wasm
54+
COMMAND ${CMAKE_COMMAND} -E copy
55+
${CMAKE_BINARY_DIR}/bin/mnist.worker.js
56+
${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.worker.js
57+
)
58+
endif()

examples/mnist/README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,23 @@ Symlinking these files will *not* work!
178178
Compile the code like so:
179179

180180
```bash
181-
$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte
181+
$ cd ../../
182+
$ mkdir -p build-em
183+
$ emcmake cmake .. -DGGML_BUILD_EXAMPLES=ON \
184+
-DCMAKE_C_FLAGS="-pthread -matomics -mbulk-memory" \
185+
-DCMAKE_CXX_FLAGS="-pthread -matomics -mbulk-memory"
186+
$ make mnist
182187
```
183188

184-
The compilation output is in `examples/mnist/web`.
189+
The compilation output is copied into `examples/mnist/web`.
185190
To run it, you need an HTTP server.
186191
For example:
187192

188193
``` bash
189-
$ cd web
190-
$ python3 -m http.server
194+
$ python3 examples/mnist/server.py
191195

192-
Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
196+
Serving directory '/home/danbev/work/ai/ggml/examples/mnist/web' at http://localhost:8000
197+
Application context root: http://localhost:8000/
193198
```
194199

195200
The web demo can then be accessed via the link printed on the console.

examples/mnist/mnist-common.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
227227
// The space in ctx_gguf exactly fits the model weights,
228228
// the images (which also need to be statically allocated) need to be put in a different context.
229229

230-
model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL);
230+
model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, nbatch_physical);
231+
231232
ggml_set_name(model.images, "images");
232233
ggml_set_input(model.images);
233234

@@ -458,7 +459,11 @@ int wasm_eval(uint8_t * digitPtr) {
458459

459460
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, 1, 1);
460461
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
461-
memcpy(data->data, digitPtr, ggml_nbytes(data));
462+
463+
float * buf = ggml_get_data_f32(data);
464+
for (int i = 0; i < MNIST_NINPUT; ++i) {
465+
buf[i] = digitPtr[i] / 255.0f;
466+
}
462467
ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed.
463468

464469
mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1);

examples/mnist/server.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import http.server
2+
import socketserver
3+
import os
4+
import sys
5+
6+
DIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), 'web'))
7+
PORT = 8000
8+
9+
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
10+
def __init__(self, *args, **kwargs):
11+
super().__init__(*args, directory=DIRECTORY, **kwargs)
12+
13+
def end_headers(self):
14+
# Add required headers for SharedArrayBuffer
15+
self.send_header("Cross-Origin-Opener-Policy", "same-origin")
16+
self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
17+
self.send_header("Access-Control-Allow-Origin", "*")
18+
super().end_headers()
19+
20+
# Enable address reuse
21+
class CustomServer(socketserver.TCPServer):
22+
allow_reuse_address = True
23+
24+
try:
25+
with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
26+
print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
27+
print(f"Application context root: http://localhost:{PORT}/")
28+
try:
29+
httpd.serve_forever()
30+
except KeyboardInterrupt:
31+
print("\nServer stopped.")
32+
# Force complete exit
33+
sys.exit(0)
34+
except OSError as e:
35+
print(f"Error: {e}")
36+
sys.exit(1)

0 commit comments

Comments
 (0)