Skip to content

Commit f32a874

Browse files
committed
resync and updated sdcpp for flux and sd3 support
1 parent 3372161 commit f32a874

30 files changed

+2434239
-1720
lines changed

otherarch/sdcpp/clip.hpp

Lines changed: 226 additions & 440 deletions
Large diffs are not rendered by default.

otherarch/sdcpp/common.hpp

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -279,26 +279,11 @@ class CrossAttention : public GGMLBlock {
279279
int64_t n_context = context->ne[1];
280280
int64_t inner_dim = d_head * n_head;
281281

282-
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
283-
q = ggml_reshape_4d(ctx, q, d_head, n_head, n_token, n); // [N, n_token, n_head, d_head]
284-
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
285-
q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * n); // [N * n_head, n_token, d_head]
282+
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
283+
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
284+
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
286285

287-
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
288-
k = ggml_reshape_4d(ctx, k, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
289-
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_context, d_head]
290-
k = ggml_reshape_3d(ctx, k, d_head, n_context, n_head * n); // [N * n_head, n_context, d_head]
291-
292-
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
293-
v = ggml_reshape_4d(ctx, v, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
294-
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, n_context]
295-
v = ggml_reshape_3d(ctx, v, n_context, d_head, n_head * n); // [N * n_head, d_head, n_context]
296-
297-
auto kqv = ggml_nn_attention(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
298-
kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, n);
299-
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
300-
301-
x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, n); // [N, n_token, inner_dim]
286+
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim]
302287

303288
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
304289
return x;
@@ -382,7 +367,7 @@ class SpatialTransformer : public GGMLBlock {
382367
int64_t n_head;
383368
int64_t d_head;
384369
int64_t depth = 1; // 1
385-
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x
370+
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
386371

387372
public:
388373
SpatialTransformer(int64_t in_channels,

otherarch/sdcpp/conditioner.hpp

Lines changed: 1206 additions & 0 deletions
Large diffs are not rendered by default.

otherarch/sdcpp/control.hpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
*/
1515
class ControlNetBlock : public GGMLBlock {
1616
protected:
17-
SDVersion version = VERSION_1_x;
17+
SDVersion version = VERSION_SD1;
1818
// network hparams
1919
int in_channels = 4;
2020
int out_channels = 4;
@@ -26,19 +26,19 @@ class ControlNetBlock : public GGMLBlock {
2626
int time_embed_dim = 1280; // model_channels*4
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
29-
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
29+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
3030

3131
public:
3232
int model_channels = 320;
33-
int adm_in_channels = 2816; // only for VERSION_XL
33+
int adm_in_channels = 2816; // only for VERSION_SDXL
3434

35-
ControlNetBlock(SDVersion version = VERSION_1_x)
35+
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
37-
if (version == VERSION_2_x) {
37+
if (version == VERSION_SD2) {
3838
context_dim = 1024;
3939
num_head_channels = 64;
4040
num_heads = -1;
41-
} else if (version == VERSION_XL) {
41+
} else if (version == VERSION_SDXL) {
4242
context_dim = 2048;
4343
attention_resolutions = {4, 2};
4444
channel_mult = {1, 2, 4};
@@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
5858
// time_embed_1 is nn.SiLU()
5959
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
6060

61-
if (version == VERSION_XL || version == VERSION_SVD) {
61+
if (version == VERSION_SDXL || version == VERSION_SVD) {
6262
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
6363
// label_emb_1 is nn.SiLU()
6464
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@@ -306,8 +306,8 @@ class ControlNetBlock : public GGMLBlock {
306306
}
307307
};
308308

309-
struct ControlNet : public GGMLModule {
310-
SDVersion version = VERSION_1_x;
309+
struct ControlNet : public GGMLRunner {
310+
SDVersion version = VERSION_SD1;
311311
ControlNetBlock control_net;
312312

313313
ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory
@@ -318,8 +318,8 @@ struct ControlNet : public GGMLModule {
318318

319319
ControlNet(ggml_backend_t backend,
320320
ggml_type wtype,
321-
SDVersion version = VERSION_1_x)
322-
: GGMLModule(backend, wtype), control_net(version) {
321+
SDVersion version = VERSION_SD1)
322+
: GGMLRunner(backend, wtype), control_net(version) {
323323
control_net.init(params_ctx, wtype);
324324
}
325325

@@ -369,14 +369,6 @@ struct ControlNet : public GGMLModule {
369369
return "control_net";
370370
}
371371

372-
size_t get_params_mem_size() {
373-
return control_net.get_params_mem_size();
374-
}
375-
376-
size_t get_params_num() {
377-
return control_net.get_params_num();
378-
}
379-
380372
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
381373
control_net.get_param_tensors(tensors, prefix);
382374
}
@@ -434,7 +426,7 @@ struct ControlNet : public GGMLModule {
434426
return build_graph(x, hint, timesteps, context, y);
435427
};
436428

437-
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
429+
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
438430
guided_hint_cached = true;
439431
}
440432

0 commit comments

Comments
 (0)