@@ -672,11 +672,11 @@ namespace Flux {
672672 }
673673
674674 // Generate IDs for image patches and text
675- std::vector<std::vector<float >> gen_ids (int h, int w, int patch_size, int bs, int context_len ) {
675+ std::vector<std::vector<float >> gen_ids (int h, int w, int patch_size, int index = 0 ) {
676676 int h_len = (h + (patch_size / 2 )) / patch_size;
677677 int w_len = (w + (patch_size / 2 )) / patch_size;
678678
679- std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(3 , 0.0 ));
679+ std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(3 , ( float )index ));
680680
681681 std::vector<float > row_ids = linspace (0 , h_len - 1 , h_len);
682682 std::vector<float > col_ids = linspace (0 , w_len - 1 , w_len);
@@ -688,10 +688,22 @@ namespace Flux {
688688 }
689689 }
690690
691- std::vector<std::vector<float >> img_ids_repeated (bs * img_ids.size (), std::vector<float >(3 ));
692- for (int i = 0 ; i < bs; ++i) {
693- for (int j = 0 ; j < img_ids.size (); ++j) {
694- img_ids_repeated[i * img_ids.size () + j] = img_ids[j];
691+ return img_ids;
692+ }
693+
694+ // Generate positional embeddings
695+ std::vector<float > gen_pe (std::vector<struct ggml_tensor *> imgs, struct ggml_tensor * context, int patch_size, int theta, const std::vector<int >& axes_dim) {
696+ int context_len = context->ne [1 ];
697+ int bs = imgs[0 ]->ne [3 ];
698+
699+ std::vector<std::vector<float >> img_ids;
700+ for (int i = 0 ; i < imgs.size (); i++) {
701+ auto x = imgs[i];
702+ if (x) {
703+ int h = x->ne [1 ];
704+ int w = x->ne [0 ];
705+ std::vector<std::vector<float >> img_ids_i = gen_ids (h, w, patch_size, i);
706+ img_ids.insert (img_ids.end (), img_ids_i.begin (), img_ids_i.end ());
695707 }
696708 }
697709
@@ -702,17 +714,10 @@ namespace Flux {
702714 ids[i * (context_len + img_ids.size ()) + j] = txt_ids[j];
703715 }
704716 for (int j = 0 ; j < img_ids.size (); ++j) {
705- ids[i * (context_len + img_ids.size ()) + context_len + j] = img_ids_repeated[i * img_ids. size () + j];
717+ ids[i * (context_len + img_ids.size ()) + context_len + j] = img_ids[ j];
706718 }
707719 }
708720
709- return ids;
710- }
711-
712-
713- // Generate positional embeddings
714- std::vector<float > gen_pe (int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int >& axes_dim) {
715- std::vector<std::vector<float >> ids = gen_ids (h, w, patch_size, bs, context_len);
716721 std::vector<std::vector<float >> trans_ids = transpose (ids);
717722 size_t pos_len = ids.size ();
718723 int num_axes = axes_dim.size ();
@@ -925,15 +930,16 @@ namespace Flux {
925930 }
926931
927932 struct ggml_tensor * forward (struct ggml_context * ctx,
928- struct ggml_tensor * x ,
933+ std::vector< struct ggml_tensor *> imgs ,
929934 struct ggml_tensor * timestep,
930935 struct ggml_tensor * context,
931936 struct ggml_tensor * c_concat,
932937 struct ggml_tensor * y,
933938 struct ggml_tensor * guidance,
934939 struct ggml_tensor * pe,
935940 struct ggml_tensor * arange = NULL ,
936- std::vector<int > skip_layers = std::vector<int >()) {
941+ std::vector<int > skip_layers = std::vector<int >(),
942+ SDVersion version = VERSION_FLUX) {
937943 // Forward pass of DiT.
938944 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
939945 // timestep: (N,) tensor of diffusion timesteps
@@ -944,18 +950,31 @@ namespace Flux {
944950 // pe: (L, d_head/2, 2, 2)
945951 // return: (N, C, H, W)
946952
953+ auto x = imgs[0 ];
947954 GGML_ASSERT (x->ne [3 ] == 1 );
948955
949956 int64_t W = x->ne [0 ];
950957 int64_t H = x->ne [1 ];
951958 int64_t C = x->ne [2 ];
952959 int64_t patch_size = 2 ;
953- int pad_h = (patch_size - H % patch_size) % patch_size;
954- int pad_w = (patch_size - W % patch_size) % patch_size;
955- x = ggml_pad (ctx, x, pad_w, pad_h, 0 , 0 ); // [N, C, H + pad_h, W + pad_w]
960+ int pad_h = (patch_size - x->ne [0 ] % patch_size) % patch_size;
961+ int pad_w = (patch_size - x->ne [1 ] % patch_size) % patch_size;
956962
957963 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
958- auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
964+ ggml_tensor* img = NULL ; // [N, h*w, C * patch_size * patch_size]
965+ int64_t patchified_img_size;
966+ for (auto & x : imgs) {
967+ int pad_h = (patch_size - x->ne [0 ] % patch_size) % patch_size;
968+ int pad_w = (patch_size - x->ne [1 ] % patch_size) % patch_size;
969+ ggml_tensor* pad_x = ggml_pad (ctx, x, pad_w, pad_h, 0 , 0 );
970+ pad_x = patchify (ctx, pad_x, patch_size);
971+ if (img) {
972+ img = ggml_concat (ctx, img, pad_x, 1 );
973+ } else {
974+ img = pad_x;
975+ patchified_img_size = img->ne [1 ];
976+ }
977+ }
959978
960979 if (c_concat != NULL ) {
961980 ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
@@ -971,6 +990,7 @@ namespace Flux {
971990 }
972991
973992 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
993+ out = ggml_cont (ctx, ggml_view_2d (ctx, out, out->ne [0 ], patchified_img_size, out->nb [1 ], 0 ));
974994
975995 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
976996 out = unpatchify (ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
@@ -1056,7 +1076,8 @@ namespace Flux {
10561076 struct ggml_tensor * c_concat,
10571077 struct ggml_tensor * y,
10581078 struct ggml_tensor * guidance,
1059- std::vector<int > skip_layers = std::vector<int >()) {
1079+ std::vector<struct ggml_tensor *> kontext_imgs = std::vector<struct ggml_tensor *>(),
1080+ std::vector<int> skip_layers = std::vector<int>()) {
10601081 GGML_ASSERT (x->ne [3 ] == 1 );
10611082 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
10621083
@@ -1067,6 +1088,9 @@ namespace Flux {
10671088 if (c_concat != NULL ) {
10681089 c_concat = to_backend (c_concat);
10691090 }
1091+ for (auto &img : kontext_imgs){
1092+ img = to_backend (img);
1093+ }
10701094 if (flux_params.is_chroma ) {
10711095 const char * SD_CHROMA_ENABLE_GUIDANCE = getenv (" SD_CHROMA_ENABLE_GUIDANCE" );
10721096 bool disable_guidance = true ;
@@ -1107,8 +1131,10 @@ namespace Flux {
11071131 if (flux_params.guidance_embed || flux_params.is_chroma ) {
11081132 guidance = to_backend (guidance);
11091133 }
1134+ auto imgs = kontext_imgs;
1135+ imgs.insert (imgs.begin (), x);
11101136
1111- pe_vec = flux.gen_pe (x-> ne [ 1 ], x-> ne [ 0 ] , 2 , x-> ne [ 3 ], context-> ne [ 1 ] , flux_params.theta , flux_params.axes_dim );
1137+ pe_vec = flux.gen_pe (imgs, context , 2 , flux_params.theta , flux_params.axes_dim );
11121138 int pos_len = pe_vec.size () / flux_params.axes_dim_sum / 2 ;
11131139 // LOG_DEBUG("pos_len %d", pos_len);
11141140 auto pe = ggml_new_tensor_4d (compute_ctx, GGML_TYPE_F32, 2 , 2 , flux_params.axes_dim_sum / 2 , pos_len);
@@ -1118,7 +1144,7 @@ namespace Flux {
11181144 set_backend_tensor_data (pe, pe_vec.data ());
11191145
11201146 struct ggml_tensor * out = flux.forward (compute_ctx,
1121- x ,
1147+ imgs ,
11221148 timesteps,
11231149 context,
11241150 c_concat,
@@ -1140,16 +1166,17 @@ namespace Flux {
11401166 struct ggml_tensor * c_concat,
11411167 struct ggml_tensor * y,
11421168 struct ggml_tensor * guidance,
1143- struct ggml_tensor ** output = NULL ,
1144- struct ggml_context * output_ctx = NULL ,
1145- std::vector<int > skip_layers = std::vector<int >()) {
1169+ std::vector<struct ggml_tensor *> kontext_imgs = std::vector<struct ggml_tensor *>(),
1170+ struct ggml_tensor** output = NULL,
1171+ struct ggml_context* output_ctx = NULL,
1172+ std::vector<int> skip_layers = std::vector<int>()) {
11461173 // x: [N, in_channels, h, w]
11471174 // timesteps: [N, ]
11481175 // context: [N, max_position, hidden_size]
11491176 // y: [N, adm_in_channels] or [1, adm_in_channels]
11501177 // guidance: [N, ]
11511178 auto get_graph = [&]() -> struct ggml_cgraph * {
1152- return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
1179+ return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers);
11531180 };
11541181
11551182 GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -1189,7 +1216,7 @@ namespace Flux {
11891216 struct ggml_tensor * out = NULL ;
11901217
11911218 int t0 = ggml_time_ms ();
1192- compute (8 , x, timesteps, context, NULL , y, guidance, &out, work_ctx);
1219+ compute (8 , x, timesteps, context, NULL , y, guidance, std::vector< struct ggml_tensor *>(), &out, work_ctx);
11931220 int t1 = ggml_time_ms ();
11941221
11951222 print_ggml_tensor (out);