Skip to content

Commit d3c7ddd

Browse files
committed
Fix logic in division by inflight_threads
1 parent 8a848cb commit d3c7ddd

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,17 @@
6262
// WebGPU implementation has bugs in handling concurrent operations. Serializing command submission
6363
// is a workaround, but we should also investigate better solutions.
6464
#ifdef GGML_WEBGPU_SERIALIZE_SUBMIT
65-
# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 1
65+
# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 1u
6666
# define WEBGPU_WAIT_ANY_TIMEOUT_MS UINT64_MAX
6767
#else
68-
# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8
68+
# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
6969
# define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
7070
#endif
7171

7272
/* Constants */
7373

7474
#define WEBGPU_MUL_MAT_WG_SIZE 256
75-
#define WEBGPU_NUM_PARAM_BUFS 32
75+
#define WEBGPU_NUM_PARAM_BUFS 32u
7676
// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
7777
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
7878
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
@@ -251,7 +251,7 @@ struct webgpu_context_struct {
251251
uint32_t max_wg_size_x;
252252

253253
std::recursive_mutex mutex;
254-
std::atomic_int inflight_threads = 0;
254+
std::atomic_uint inflight_threads = 0;
255255

256256
webgpu_buf_pool param_buf_pool;
257257
webgpu_buf_pool set_rows_error_buf_pool;
@@ -379,7 +379,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct
379379
uint64_t timeout_ms = UINT64_MAX) {
380380
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
381381
// inflight_max may be 0, meaning that we must wait on all futures.
382-
int inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::min(ctx->inflight_threads, 1);
382+
uint inflight_threads = ctx->inflight_threads;
383+
uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
383384
while (futures.size() >= inflight_max && futures.size() > 0) {
384385
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
385386
futures.erase(futures.begin());
@@ -1279,8 +1280,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
12791280
commands.push_back(*cmd);
12801281
}
12811282
// compute the batch size based on the number of inflight threads
1282-
int batch_size = std::min(std::max(1, WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads),
1283-
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1283+
uint inflight_threads = ctx->inflight_threads;
1284+
uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
1285+
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
12841286
if (commands.size() >= batch_size) {
12851287
futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
12861288
// Process events and check for completed submissions

0 commit comments

Comments
 (0)