
Commit 7ce4f3b

conv direct as a flag
1 parent f5b5f5c commit 7ce4f3b

11 files changed: +161 −111 lines

CMakeLists.txt
Lines changed: 0 additions & 7 deletions

@@ -31,7 +31,6 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
 option(SD_OPENCL "sd: opencl backend" OFF)
 option(SD_SYCL "sd: sycl backend" OFF)
 option(SD_MUSA "sd: musa backend" OFF)
-option(SD_CONV2D_DIRECT "sd: enable conv2d direct support" OFF)
 option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -58,7 +57,6 @@ if (SD_OPENCL)
     message("-- Use OpenCL as backend stable-diffusion")
     set(GGML_OPENCL ON)
     add_definitions(-DSD_USE_OPENCL)
-    add_definitions(-DSD_USE_CONV2D_DIRECT)
 endif ()

 if (SD_HIPBLAS)
@@ -79,11 +77,6 @@ if(SD_MUSA)
     endif()
 endif()

-if(SD_CONV2D_DIRECT)
-    message("-- Use CONV2D Direct for VAE")
-    add_definitions(-DSD_USE_CONV2D_DIRECT)
-endif ()
-
 set(SD_LIB stable-diffusion)

 file(GLOB SD_LIB_SOURCES
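The net effect is that direct convolution stops being a build-time decision: the SD_CONV2D_DIRECT option and its SD_USE_CONV2D_DIRECT compile definition are removed entirely. Where a build was previously configured with the now-removed option, for example

    cmake .. -DSD_CONV2D_DIRECT=ON

the behavior is instead requested per run through the new --diffusion-conv-direct and --vae-conv-direct CLI flags documented in the README change below.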

README.md
Lines changed: 2 additions & 0 deletions

@@ -339,6 +339,8 @@ arguments:
   --vae-on-cpu                       keep vae in cpu (for low vram)
   --clip-on-cpu                      keep clip in cpu (for low vram)
   --diffusion-fa                     use flash attention in the diffusion model (for low vram)
+  --diffusion-conv-direct            use Conv2D direct in the diffusion model
+  --vae-conv-direct                  use Conv2D direct in the vae model (should improve the performance)
                                      Might lower quality, since it implies converting k and v to f16.
                                      This might crash if it is not supported by the backend.
   --control-net-cpu                  keep controlnet in cpu (for low vram)
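For reference, an invocation using the new flags might look like the following (a sketch: the model path and prompt are placeholders, and sd is the CLI binary built from examples/cli):

    ./sd -m sd-v1-4.safetensors -p "a lovely cat" --diffusion-conv-direct --vae-conv-direct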

common.hpp
Lines changed: 17 additions & 10 deletions

@@ -8,18 +8,21 @@ class DownSampleBlock : public GGMLBlock {
     int channels;
     int out_channels;
     bool vae_downsample;
+    bool direct = false;

 public:
     DownSampleBlock(int channels,
                     int out_channels,
-                    bool vae_downsample = false)
+                    bool vae_downsample = false,
+                    bool direct         = false)
         : channels(channels),
           out_channels(out_channels),
-          vae_downsample(vae_downsample) {
+          vae_downsample(vae_downsample),
+          direct(direct) {
         if (vae_downsample) {
-            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, true));
+            blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, direct));
         } else {
-            blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
+            blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, true, direct));
         }
     }

@@ -43,13 +46,16 @@ class UpSampleBlock : public GGMLBlock {
 protected:
     int channels;
     int out_channels;
+    bool direct = false;

 public:
     UpSampleBlock(int channels,
-                  int out_channels)
+                  int out_channels,
+                  bool direct = false)
         : channels(channels),
-          out_channels(out_channels) {
-        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
+          out_channels(out_channels),
+          direct(direct) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -381,7 +387,8 @@ class SpatialTransformer : public GGMLBlock {
                        int64_t d_head,
                        int64_t depth,
                        int64_t context_dim,
-                       bool flash_attn = false)
+                       bool flash_attn = false,
+                       bool direct     = false)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -391,14 +398,14 @@ class SpatialTransformer : public GGMLBlock {
         // disable_self_attn is always False
         int64_t inner_dim = n_head * d_head;  // in_channels
         blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, direct));

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
             blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
         }

-        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, direct));
     }

     virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
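From these call sites, the Conv2d constructor being targeted appears to have the following shape (a sketch inferred from the diff; parameter names and defaults are assumptions, not copied from the project):

    // Assumed declaration matching the call sites above: kernel/stride/
    // padding/dilation are {height, width} pairs, `bias` toggles the bias
    // term, and the trailing `direct` selects the direct conv2d path.
    Conv2d(int64_t in_channels,
           int64_t out_channels,
           std::pair<int, int> kernel_size,
           std::pair<int, int> stride   = {1, 1},
           std::pair<int, int> padding  = {0, 0},
           std::pair<int, int> dilation = {1, 1},
           bool bias   = true,
           bool direct = false);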

diffusion_model.hpp
Lines changed: 3 additions & 2 deletions

@@ -34,8 +34,9 @@ struct UNetModel : public DiffusionModel {
     UNetModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_SD1,
-              bool flash_attn   = false)
-        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
+              bool flash_attn   = false,
+              bool direct       = false)
+        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn, direct) {
     }

     void alloc_params_buffer() {

examples/cli/main.cpp
Lines changed: 10 additions & 0 deletions

@@ -97,6 +97,8 @@ struct SDParams {
     bool clip_on_cpu          = false;
     bool vae_on_cpu           = false;
     bool diffusion_flash_attn = false;
+    bool diffusion_conv_direct = false;
+    bool vae_conv_direct       = false;
     bool canny_preprocess     = false;
     bool color                = false;
     int upscale_repeats       = 1;
@@ -142,6 +144,8 @@ void print_params(SDParams params) {
     printf("    controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
     printf("    vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
     printf("    diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
+    printf("    diffusion Conv2D direct:%s\n", params.diffusion_conv_direct ? "true" : "false");
+    printf("    vae Conv2D direct:%s\n", params.vae_conv_direct ? "true" : "false");
     printf("    strength(control): %.2f\n", params.control_strength);
     printf("    prompt: %s\n", params.prompt.c_str());
     printf("    negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -232,6 +236,8 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --diffusion-fa                     use flash attention in the diffusion model (for low vram)\n");
     printf("                                     Might lower quality, since it implies converting k and v to f16.\n");
     printf("                                     This might crash if it is not supported by the backend.\n");
+    printf("  --diffusion-conv-direct            use Conv2D direct in the diffusion model");
+    printf("  --vae-conv-direct                  use Conv2D direct in the vae model (should improve the performance)");
     printf("  --control-net-cpu                  keep controlnet in cpu (for low vram)\n");
     printf("  --canny                            apply canny preprocessor (edge detection)\n");
     printf("  --color                            colors the logging tags according to level\n");
@@ -422,6 +428,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--clip-on-cpu", "", true, &params.clip_on_cpu},
         {"", "--vae-on-cpu", "", true, &params.vae_on_cpu},
         {"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
+        {"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
+        {"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
         {"", "--canny", "", true, &params.canny_preprocess},
         {"-v", "--verbos", "", true, &params.verbose},
         {"", "--color", "", true, &params.color},
@@ -901,6 +909,8 @@ int main(int argc, const char* argv[]) {
                                       params.control_net_cpu,
                                       params.vae_on_cpu,
                                       params.diffusion_flash_attn,
+                                      params.diffusion_conv_direct,
+                                      params.vae_conv_direct,
                                       params.chroma_use_dit_mask,
                                       params.chroma_use_t5_mask,
                                       params.chroma_t5_mask_pad,

ggml_extend.hpp
Lines changed: 3 additions & 7 deletions

@@ -1514,14 +1514,10 @@ class Conv2d : public UnaryBlock {
         direct = true
 #endif
         if (direct) {
-#if defined(SD_USE_CONV2D_DIRECT)
-#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL)
-            return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
-#else
-            return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
-#endif
-#else
+#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL)
             return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+#else
+            return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
 #endif
         } else {
             return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
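With the nested SD_USE_CONV2D_DIRECT guard gone, `direct` is a pure runtime flag: CUDA, SYCL, and Metal still route to the im2col-based helper, and only the remaining backends take the direct path. As a rough sketch of what the direct-path helper is assumed to wrap (an illustration around ggml's ggml_conv_2d_direct, not the project's exact code):

    // Assumed shape of the direct-path helper: ggml_conv_2d_direct computes
    // the convolution without an im2col intermediate, then the bias (one
    // value per output channel) is reshaped so ggml_add can broadcast it.
    struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
                                               struct ggml_tensor* x,
                                               struct ggml_tensor* w,
                                               struct ggml_tensor* b,
                                               int s0, int s1, int p0, int p1,
                                               int d0, int d1) {
        struct ggml_tensor* out = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
        if (b != NULL) {
            b   = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
            out = ggml_add(ctx, out, b);
        }
        return out;
    }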

stable-diffusion.cpp
Lines changed: 12 additions & 3 deletions

@@ -326,6 +326,12 @@ class StableDiffusionGGML {
             LOG_INFO("CLIP: Using CPU backend");
             clip_backend = ggml_backend_cpu_init();
         }
+        if (sd_ctx_params->diffusion_conv_direct) {
+            LOG_INFO("Using Conv2D direct in the diffusion model");
+        }
+        if (sd_ctx_params->vae_conv_direct) {
+            LOG_INFO("Using Conv2D direct in the vae model");
+        }
         if (sd_ctx_params->diffusion_flash_attn) {
             LOG_INFO("Using flash attention in the diffusion model");
         }
@@ -373,7 +379,8 @@ class StableDiffusionGGML {
             diffusion_model = std::make_shared<UNetModel>(backend,
                                                           model_loader.tensor_storages_types,
                                                           version,
-                                                          sd_ctx_params->diffusion_flash_attn);
+                                                          sd_ctx_params->diffusion_flash_attn,
+                                                          sd_ctx_params->diffusion_conv_direct);
         }

         cond_stage_model->alloc_params_buffer();
@@ -394,15 +401,17 @@ class StableDiffusionGGML {
                                                                "first_stage_model",
                                                                vae_decode_only,
                                                                false,
-                                                               version);
+                                                               version,
+                                                               sd_ctx_params->vae_conv_direct);
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
         } else {
             tae_first_stage = std::make_shared<TinyAutoEncoder>(backend,
                                                                 model_loader.tensor_storages_types,
                                                                 "decoder.layers",
                                                                 vae_decode_only,
-                                                                version);
+                                                                version,
+                                                                sd_ctx_params->vae_conv_direct);
         }
         // first_stage_model->get_param_tensors(tensors, "first_stage_model.");

stable-diffusion.h
Lines changed: 2 additions & 0 deletions

@@ -134,6 +134,8 @@ typedef struct {
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
     bool diffusion_flash_attn;
+    bool diffusion_conv_direct;
+    bool vae_conv_direct;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
     int chroma_t5_mask_pad;
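A minimal usage sketch against the public API (hypothetical: the struct's typedef name sd_ctx_params_t, the sd_ctx_params_init() initializer, and new_sd_ctx() taking a params pointer are assumptions about the surrounding header, not shown in this diff):

    // Hypothetical usage; only the two bool fields are confirmed by this diff.
    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);          // assumed default-initializer
    ctx_params.diffusion_conv_direct = true;  // direct conv2d in the UNet
    ctx_params.vae_conv_direct       = true;  // direct conv2d in the VAE / TAE
    sd_ctx_t* ctx = new_sd_ctx(&ctx_params);  // assumed constructor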
