@@ -1084,41 +1084,76 @@ struct FluxCLIPEmbedder : public Conditioner {
10841084 auto & t5_tokens = token_and_weights[1 ].first ;
10851085 auto & t5_weights = token_and_weights[1 ].second ;
10861086
1087- int64_t t0 = ggml_time_ms ();
1088- struct ggml_tensor * hidden_states = NULL ; // [N, n_token, 4096]
1089- struct ggml_tensor * chunk_hidden_states = NULL ; // [n_token, 4096]
1090- struct ggml_tensor * pooled = NULL ; // [768,]
1087+ int64_t t0 = ggml_time_ms ();
1088+ struct ggml_tensor * hidden_states = NULL ; // [N, n_token, 4096]
1089+ struct ggml_tensor * chunk_hidden_states = NULL ; // [n_token*2, 4096]
1090+ struct ggml_tensor * chunk_hidden_states_l = NULL ; // [n_token, hidden_size_l]
1091+ struct ggml_tensor * chunk_hidden_states_t5 = NULL ; // [n_token, hidden_size_t5]
1092+ struct ggml_tensor * pooled = NULL ; // [768,]
10911093 std::vector<float > hidden_states_vec;
10921094
1093- size_t chunk_len = 256 ;
1094- size_t chunk_count = t5_tokens .size () / chunk_len;
1095+ size_t chunk_len = 77 ;
1096+ size_t chunk_count = clip_l_tokens .size () / chunk_len;
10951097 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
10961098 // clip_l
1097- if (chunk_idx == 0 ) {
1098- size_t chunk_len_l = 77 ;
1099- std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1100- clip_l_tokens.begin () + chunk_len_l);
1101- std::vector<float > chunk_weights (clip_l_weights.begin (),
1102- clip_l_weights.begin () + chunk_len_l);
1099+ {
1100+ std::vector<int > chunk_tokens (clip_l_tokens.begin () + chunk_idx * chunk_len,
1101+ clip_l_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
1102+ std::vector<float > chunk_weights (clip_l_weights.begin () + chunk_idx * chunk_len,
1103+ clip_l_weights.begin () + (chunk_idx + 1 ) * chunk_len);
11031104
11041105 auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
11051106 size_t max_token_idx = 0 ;
11061107
1107- // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1108- // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1109- // clip_l->compute(n_threads,
1110- // input_ids,
1111- // 0,
1112- // NULL,
1113- // max_token_idx,
1114- // true,
1115- // &pooled,
1116- // work_ctx);
1117-
1118- // clip_l.transformer.text_model.text_projection no in file, ignore
1119- // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
1120- pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
1121- ggml_set_f32 (pooled, 0 .f );
1108+ clip_l->compute (n_threads,
1109+ input_ids,
1110+ 0 ,
1111+ NULL ,
1112+ max_token_idx,
1113+ false ,
1114+ &chunk_hidden_states_l,
1115+ work_ctx);
1116+ {
1117+ auto tensor = chunk_hidden_states_l;
1118+ float original_mean = ggml_tensor_mean (tensor);
1119+ for (int i2 = 0 ; i2 < tensor->ne [2 ]; i2++) {
1120+ for (int i1 = 0 ; i1 < tensor->ne [1 ]; i1++) {
1121+ for (int i0 = 0 ; i0 < tensor->ne [0 ]; i0++) {
1122+ float value = ggml_tensor_get_f32 (tensor, i0, i1, i2);
1123+ value *= chunk_weights[i1];
1124+ ggml_tensor_set_f32 (tensor, value, i0, i1, i2);
1125+ }
1126+ }
1127+ }
1128+ float new_mean = ggml_tensor_mean (tensor);
1129+ ggml_tensor_scale (tensor, (original_mean / new_mean));
1130+ }
1131+ if (chunk_idx == 0 ) {
1132+ size_t chunk_len_l = 77 ;
1133+ std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1134+ clip_l_tokens.begin () + chunk_len_l);
1135+ std::vector<float > chunk_weights (clip_l_weights.begin (),
1136+ clip_l_weights.begin () + chunk_len_l);
1137+
1138+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1139+ size_t max_token_idx = 0 ;
1140+
1141+ // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1142+ // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1143+ // clip_l->compute(n_threads,
1144+ // input_ids,
1145+ // 0,
1146+ // NULL,
1147+ // max_token_idx,
1148+ // true,
1149+ // &pooled,
1150+ // work_ctx);
1151+
1152+ // clip_l.transformer.text_model.text_projection no in file, ignore
1153+ // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
1154+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
1155+ ggml_set_f32 (pooled, 0 .f );
1156+ }
11221157 }
11231158
11241159 // t5
@@ -1132,10 +1167,10 @@ struct FluxCLIPEmbedder : public Conditioner {
11321167
11331168 t5->compute (n_threads,
11341169 input_ids,
1135- &chunk_hidden_states ,
1170+ &chunk_hidden_states_t5 ,
11361171 work_ctx);
11371172 {
1138- auto tensor = chunk_hidden_states ;
1173+ auto tensor = chunk_hidden_states_t5 ;
11391174 float original_mean = ggml_tensor_mean (tensor);
11401175 for (int i2 = 0 ; i2 < tensor->ne [2 ]; i2++) {
11411176 for (int i1 = 0 ; i1 < tensor->ne [1 ]; i1++) {
@@ -1151,6 +1186,29 @@ struct FluxCLIPEmbedder : public Conditioner {
11511186 }
11521187 }
11531188
1189+
1190+ // TODO: Maybe there's a better way to do the padding?
1191+ auto chunk_hidden_states_l_pad = ggml_new_tensor_3d (work_ctx,
1192+ chunk_hidden_states_l->type ,
1193+ 4096 ,
1194+ chunk_hidden_states_l->ne [1 ],
1195+ chunk_hidden_states_l->ne [2 ]); // [n_token, 4096]
1196+
1197+ for (int i2 = 0 ; i2 < chunk_hidden_states_l_pad->ne [2 ]; i2++) {
1198+ for (int i1 = 0 ; i1 < chunk_hidden_states_l_pad->ne [1 ]; i1++) {
1199+ for (int i0 = 0 ; i0 < chunk_hidden_states_l_pad->ne [0 ]; i0++) {
1200+ float value = 0 .f ;
1201+ if (i0 < chunk_hidden_states_l->ne [0 ]) {
1202+ value = ggml_tensor_get_f32 (chunk_hidden_states_l, i0, i1, i2);
1203+ }
1204+ ggml_tensor_set_f32 (chunk_hidden_states_l_pad, value, i0, i1, i2);
1205+ }
1206+ }
1207+ }
1208+
1209+ chunk_hidden_states = ggml_tensor_concat (work_ctx, chunk_hidden_states_l, chunk_hidden_states_t5, 1 ); // [n_token*2, 4096]
1210+
1211+
11541212 int64_t t1 = ggml_time_ms ();
11551213 LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
11561214 if (force_zero_embeddings) {
0 commit comments