Skip to content

Commit 9bbf53c

Browse files
committed
Align conv2d behavior to flash attention's
1 parent 70cef96 commit 9bbf53c

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

examples/cli/main.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,9 @@ void print_usage(int argc, const char* argv[]) {
237237
printf(" Might lower quality, since it implies converting k and v to f16.\n");
238238
printf(" This might crash if it is not supported by the backend.\n");
239239
printf(" --diffusion-conv-direct use Conv2D direct in the diffusion model");
240+
printf(" This might crash if it is not supported by the backend.\n");
240241
printf(" --vae-conv-direct use Conv2D direct in the vae model (should improve the performance)");
242+
printf(" This might crash if it is not supported by the backend.\n");
241243
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
242244
printf(" --canny apply canny preprocessor (edge detection)\n");
243245
printf(" --color colors the logging tags according to level\n");

ggml_extend.hpp

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1518,15 +1518,8 @@ class Conv2d : public UnaryBlock {
15181518
if (bias) {
15191519
b = params["bias"];
15201520
}
1521-
#if defined(SD_USE_OPENCL)
1522-
direct = true
1523-
#endif
15241521
if (direct) {
1525-
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL)
1526-
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1527-
#else
1528-
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1529-
#endif
1522+
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
15301523
} else {
15311524
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
15321525
}

0 commit comments

Comments
 (0)