
Commit ed28922

added support for flux kontext
1 parent 0bd648f commit ed28922

File tree

6 files changed: +202 additions, -70 deletions


otherarch/sdcpp/diffusion_model.hpp

Lines changed: 29 additions & 25 deletions
@@ -13,12 +13,13 @@ struct DiffusionModel {
                             struct ggml_tensor* c_concat,
                             struct ggml_tensor* y,
                             struct ggml_tensor* guidance,
-                            int num_video_frames = -1,
-                            std::vector<struct ggml_tensor*> controls = {},
-                            float control_strength = 0.f,
-                            struct ggml_tensor** output = NULL,
-                            struct ggml_context* output_ctx = NULL,
-                            std::vector<int> skip_layers = std::vector<int>()) = 0;
+                            int num_video_frames = -1,
+                            std::vector<struct ggml_tensor*> controls = {},
+                            float control_strength = 0.f,
+                            std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                            struct ggml_tensor** output = NULL,
+                            struct ggml_context* output_ctx = NULL,
+                            std::vector<int> skip_layers = std::vector<int>()) = 0;
     virtual void alloc_params_buffer() = 0;
     virtual void free_params_buffer() = 0;
     virtual void free_compute_buffer() = 0;
@@ -68,12 +69,13 @@ struct UNetModel : public DiffusionModel {
                             struct ggml_tensor* c_concat,
                             struct ggml_tensor* y,
                             struct ggml_tensor* guidance,
-                            int num_video_frames = -1,
-                            std::vector<struct ggml_tensor*> controls = {},
-                            float control_strength = 0.f,
-                            struct ggml_tensor** output = NULL,
-                            struct ggml_context* output_ctx = NULL,
-                            std::vector<int> skip_layers = std::vector<int>()) {
+                            int num_video_frames = -1,
+                            std::vector<struct ggml_tensor*> controls = {},
+                            float control_strength = 0.f,
+                            std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                            struct ggml_tensor** output = NULL,
+                            struct ggml_context* output_ctx = NULL,
+                            std::vector<int> skip_layers = std::vector<int>()) {
         (void)skip_layers; // SLG doesn't work with UNet models
         return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
     }
@@ -118,12 +120,13 @@ struct MMDiTModel : public DiffusionModel {
                             struct ggml_tensor* c_concat,
                             struct ggml_tensor* y,
                             struct ggml_tensor* guidance,
-                            int num_video_frames = -1,
-                            std::vector<struct ggml_tensor*> controls = {},
-                            float control_strength = 0.f,
-                            struct ggml_tensor** output = NULL,
-                            struct ggml_context* output_ctx = NULL,
-                            std::vector<int> skip_layers = std::vector<int>()) {
+                            int num_video_frames = -1,
+                            std::vector<struct ggml_tensor*> controls = {},
+                            float control_strength = 0.f,
+                            std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                            struct ggml_tensor** output = NULL,
+                            struct ggml_context* output_ctx = NULL,
+                            std::vector<int> skip_layers = std::vector<int>()) {
         return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
     }
 };
@@ -169,13 +172,14 @@ struct FluxModel : public DiffusionModel {
                             struct ggml_tensor* c_concat,
                             struct ggml_tensor* y,
                             struct ggml_tensor* guidance,
-                            int num_video_frames = -1,
-                            std::vector<struct ggml_tensor*> controls = {},
-                            float control_strength = 0.f,
-                            struct ggml_tensor** output = NULL,
-                            struct ggml_context* output_ctx = NULL,
-                            std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
+                            int num_video_frames = -1,
+                            std::vector<struct ggml_tensor*> controls = {},
+                            float control_strength = 0.f,
+                            std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                            struct ggml_tensor** output = NULL,
+                            struct ggml_context* output_ctx = NULL,
+                            std::vector<int> skip_layers = std::vector<int>()) {
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers);
     }
 };
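
The only interface change in this file is the new kontext_imgs parameter, defaulted to an empty vector so existing call sites keep compiling; FluxModel is the only implementation that actually forwards it. Below is a minimal call-site sketch, not code from this commit: the parameter order follows the FluxModel::compute forwarding line above, while diffusion_model, ref_latent, noised, timesteps, context, y, guidance, n_threads and work_ctx are hypothetical placeholders.

// Hypothetical call site for the updated interface; every identifier here is a
// placeholder, only the parameter order is taken from the diff above.
std::vector<struct ggml_tensor*> kontext_imgs = {ref_latent};  // e.g. one VAE-encoded reference image

struct ggml_tensor* denoised = NULL;
diffusion_model->compute(n_threads,
                         noised,        // x: latent being denoised
                         timesteps,
                         context,       // text conditioning
                         NULL,          // c_concat
                         y,
                         guidance,
                         -1,            // num_video_frames (default)
                         {},            // controls
                         0.f,           // control_strength
                         kontext_imgs,  // new in this commit
                         &denoised,
                         work_ctx);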

otherarch/sdcpp/flux.hpp

Lines changed: 55 additions & 28 deletions
@@ -672,11 +672,11 @@ namespace Flux {
     }
 
     // Generate IDs for image patches and text
-    std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len) {
+    std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int index = 0) {
         int h_len = (h + (patch_size / 2)) / patch_size;
         int w_len = (w + (patch_size / 2)) / patch_size;
 
-        std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
+        std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, (float)index));
 
         std::vector<float> row_ids = linspace(0, h_len - 1, h_len);
         std::vector<float> col_ids = linspace(0, w_len - 1, w_len);
@@ -688,10 +688,22 @@ namespace Flux {
             }
         }
 
-        std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
-        for (int i = 0; i < bs; ++i) {
-            for (int j = 0; j < img_ids.size(); ++j) {
-                img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
+        return img_ids;
+    }
+
+    // Generate positional embeddings
+    std::vector<float> gen_pe(std::vector<struct ggml_tensor*> imgs, struct ggml_tensor* context, int patch_size, int theta, const std::vector<int>& axes_dim) {
+        int context_len = context->ne[1];
+        int bs = imgs[0]->ne[3];
+
+        std::vector<std::vector<float>> img_ids;
+        for (int i = 0; i < imgs.size(); i++) {
+            auto x = imgs[i];
+            if (x) {
+                int h = x->ne[1];
+                int w = x->ne[0];
+                std::vector<std::vector<float>> img_ids_i = gen_ids(h, w, patch_size, i);
+                img_ids.insert(img_ids.end(), img_ids_i.begin(), img_ids_i.end());
             }
         }
 
@@ -702,17 +714,10 @@ namespace Flux {
                 ids[i * (context_len + img_ids.size()) + j] = txt_ids[j];
             }
             for (int j = 0; j < img_ids.size(); ++j) {
-                ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j];
+                ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids[j];
             }
         }
 
-        return ids;
-    }
-
-
-    // Generate positional embeddings
-    std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
-        std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len);
         std::vector<std::vector<float>> trans_ids = transpose(ids);
         size_t pos_len = ids.size();
         int num_axes = axes_dim.size();
@@ -925,15 +930,16 @@ namespace Flux {
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
-                                struct ggml_tensor* x,
+                                std::vector<struct ggml_tensor*> imgs,
                                 struct ggml_tensor* timestep,
                                 struct ggml_tensor* context,
                                 struct ggml_tensor* c_concat,
                                 struct ggml_tensor* y,
                                 struct ggml_tensor* guidance,
                                 struct ggml_tensor* pe,
                                 struct ggml_tensor* arange = NULL,
-                                std::vector<int> skip_layers = std::vector<int>()) {
+                                std::vector<int> skip_layers = std::vector<int>(),
+                                SDVersion version = VERSION_FLUX) {
         // Forward pass of DiT.
         // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         // timestep: (N,) tensor of diffusion timesteps
@@ -944,18 +950,31 @@ namespace Flux {
         // pe: (L, d_head/2, 2, 2)
         // return: (N, C, H, W)
 
+        auto x = imgs[0];
         GGML_ASSERT(x->ne[3] == 1);
 
         int64_t W = x->ne[0];
         int64_t H = x->ne[1];
        int64_t C = x->ne[2];
         int64_t patch_size = 2;
-        int pad_h = (patch_size - H % patch_size) % patch_size;
-        int pad_w = (patch_size - W % patch_size) % patch_size;
-        x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
+        int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size;
+        int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size;
 
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-        auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
+        ggml_tensor* img = NULL; // [N, h*w, C * patch_size * patch_size]
+        int64_t patchified_img_size;
+        for (auto& x : imgs) {
+            int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size;
+            int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size;
+            ggml_tensor* pad_x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);
+            pad_x = patchify(ctx, pad_x, patch_size);
+            if (img) {
+                img = ggml_concat(ctx, img, pad_x, 1);
+            } else {
+                img = pad_x;
+                patchified_img_size = img->ne[1];
+            }
+        }
 
         if (c_concat != NULL) {
             ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
@@ -971,6 +990,7 @@ namespace Flux {
         }
 
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
+        out = ggml_cont(ctx, ggml_view_2d(ctx, out, out->ne[0], patchified_img_size, out->nb[1], 0));
 
         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
         out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
@@ -1056,7 +1076,8 @@ namespace Flux {
                          struct ggml_tensor* c_concat,
                          struct ggml_tensor* y,
                          struct ggml_tensor* guidance,
-                         std::vector<int> skip_layers = std::vector<int>()) {
+                         std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                         std::vector<int> skip_layers = std::vector<int>()) {
         GGML_ASSERT(x->ne[3] == 1);
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
@@ -1067,6 +1088,9 @@ namespace Flux {
         if (c_concat != NULL) {
             c_concat = to_backend(c_concat);
         }
+        for (auto& img : kontext_imgs) {
+            img = to_backend(img);
+        }
         if (flux_params.is_chroma) {
             const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE");
             bool disable_guidance = true;
@@ -1107,8 +1131,10 @@ namespace Flux {
         if (flux_params.guidance_embed || flux_params.is_chroma) {
            guidance = to_backend(guidance);
         }
+        auto imgs = kontext_imgs;
+        imgs.insert(imgs.begin(), x);
 
-        pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim);
+        pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim);
         int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
         // LOG_DEBUG("pos_len %d", pos_len);
         auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
@@ -1118,7 +1144,7 @@ namespace Flux {
         set_backend_tensor_data(pe, pe_vec.data());
 
         struct ggml_tensor* out = flux.forward(compute_ctx,
-                                               x,
+                                               imgs,
                                                timesteps,
                                                context,
                                                c_concat,
@@ -1140,16 +1166,17 @@ namespace Flux {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) {
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // context: [N, max_position, hidden_size]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         // guidance: [N, ]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
+            return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers);
         };
 
         GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1189,7 +1216,7 @@ namespace Flux {
         struct ggml_tensor* out = NULL;
 
         int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
+        compute(8, x, timesteps, context, NULL, y, guidance, std::vector<struct ggml_tensor*>(), &out, work_ctx);
         int t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
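
Taken together, the flux.hpp changes patchify every kontext image and append its tokens after the target latent along the sequence axis, give each extra image its own index value in the first RoPE ID component via gen_ids(h, w, patch_size, i) while rows and columns fill the other two components, and, after forward_orig, slice only the first patchified_img_size tokens back out (the ggml_view_2d above) so the unpatchified output keeps the target latent's shape. The standalone sketch below mirrors that ID bookkeeping without ggml; gen_ids_sketch is a local stand-in rather than the repository's gen_ids, and the all-zero text IDs are an assumption carried over from upstream Flux.

// Standalone sketch of the position-ID layout produced by the new gen_pe path
// (illustration only, not the repository's code).
// Build: g++ -std=c++17 -o kontext_ids kontext_ids.cpp && ./kontext_ids
#include <array>
#include <cstdio>
#include <utility>
#include <vector>

// One (index, row, col) triple per latent patch; index is 0 for the target latent
// and i >= 1 for the i-th kontext image, matching gen_ids(h, w, patch_size, i).
static std::vector<std::array<float, 3>> gen_ids_sketch(int h, int w, int patch_size, int index) {
    int h_len = (h + (patch_size / 2)) / patch_size;
    int w_len = (w + (patch_size / 2)) / patch_size;
    std::vector<std::array<float, 3>> ids;
    ids.reserve((size_t)h_len * w_len);
    for (int r = 0; r < h_len; ++r) {
        for (int c = 0; c < w_len; ++c) {
            ids.push_back({(float)index, (float)r, (float)c});
        }
    }
    return ids;
}

int main() {
    const int patch_size  = 2;
    const int context_len = 256;                  // number of text tokens (assumed)
    std::vector<std::pair<int, int>> latents = {  // {h, w} of each latent
        {64, 64},                                 // target latent  -> index 0
        {64, 64},                                 // kontext latent -> index 1
    };

    // Text IDs come first; upstream Flux fills them with zeros (assumption here).
    std::vector<std::array<float, 3>> ids(context_len, {0.f, 0.f, 0.f});
    size_t target_tokens = 0;
    for (size_t i = 0; i < latents.size(); ++i) {
        auto img_ids = gen_ids_sketch(latents[i].first, latents[i].second, patch_size, (int)i);
        if (i == 0) {
            target_tokens = img_ids.size();  // only these survive the ggml_view_2d slice
        }
        ids.insert(ids.end(), img_ids.begin(), img_ids.end());
    }
    printf("total ids: %zu (text %d + image %zu); tokens kept after slicing: %zu\n",
           ids.size(), context_len, ids.size() - (size_t)context_len, target_tokens);
    return 0;
}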
