Skip to content

Commit fc91520

Browse files
committed
Check for tensor size in supports_op
1 parent b0bd49f commit fc91520

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,18 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
12751275
}
12761276

12771277
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
1278-
GGML_UNUSED(dev);
1278+
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1279+
1280+
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
1281+
1282+
ggml_tensor * src0 = op->src[0];
1283+
ggml_tensor * src1 = op->src[1];
1284+
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
1285+
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
1286+
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
1287+
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
1288+
return false;
1289+
}
12791290

12801291
bool supports_op = false;
12811292
switch (op->op) {
@@ -1399,19 +1410,20 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
13991410
webgpu_context ctx = reg_ctx->webgpu_ctx;
14001411

14011412
wgpu::RequestAdapterOptions options = {};
1402-
ctx->instance.WaitAny(
1403-
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous,
1404-
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
1405-
if (status != wgpu::RequestAdapterStatus::Success) {
1406-
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
1407-
return;
1408-
}
1409-
ctx->adapter = std::move(adapter);
1410-
}), UINT64_MAX);
1413+
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
1414+
&options, wgpu::CallbackMode::AllowSpontaneous,
1415+
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
1416+
if (status != wgpu::RequestAdapterStatus::Success) {
1417+
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
1418+
return;
1419+
}
1420+
ctx->adapter = std::move(adapter);
1421+
}),
1422+
UINT64_MAX);
14111423
GGML_ASSERT(ctx->adapter != nullptr);
14121424

14131425
ctx->adapter.GetLimits(&ctx->limits);
1414-
ctx->max_wg_size_x = 256; // default value
1426+
ctx->max_wg_size_x = 256; // default value
14151427

14161428
wgpu::AdapterInfo info{};
14171429
ctx->adapter.GetInfo(&info);

0 commit comments

Comments
 (0)