Skip to content

Commit ff4976e

Browse files
committed
Flux: clip_l support
1 parent ac54e00 commit ff4976e

File tree

1 file changed

+87
-29
lines changed

1 file changed

+87
-29
lines changed

conditioner.hpp

Lines changed: 87 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1084,41 +1084,76 @@ struct FluxCLIPEmbedder : public Conditioner {
10841084
auto& t5_tokens = token_and_weights[1].first;
10851085
auto& t5_weights = token_and_weights[1].second;
10861086

1087-
int64_t t0 = ggml_time_ms();
1088-
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
1089-
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
1090-
struct ggml_tensor* pooled = NULL; // [768,]
1087+
int64_t t0 = ggml_time_ms();
1088+
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
1089+
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token*2, 4096]
1090+
struct ggml_tensor* chunk_hidden_states_l = NULL; // [n_token, hidden_size_l]
1091+
struct ggml_tensor* chunk_hidden_states_t5 = NULL; // [n_token, hidden_size_t5]
1092+
struct ggml_tensor* pooled = NULL; // [768,]
10911093
std::vector<float> hidden_states_vec;
10921094

1093-
size_t chunk_len = 256;
1094-
size_t chunk_count = t5_tokens.size() / chunk_len;
1095+
size_t chunk_len = 77;
1096+
size_t chunk_count = clip_l_tokens.size() / chunk_len;
10951097
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
10961098
// clip_l
1097-
if (chunk_idx == 0) {
1098-
size_t chunk_len_l = 77;
1099-
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1100-
clip_l_tokens.begin() + chunk_len_l);
1101-
std::vector<float> chunk_weights(clip_l_weights.begin(),
1102-
clip_l_weights.begin() + chunk_len_l);
1099+
{
1100+
std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
1101+
clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
1102+
std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
1103+
clip_l_weights.begin() + (chunk_idx + 1) * chunk_len);
11031104

11041105
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
11051106
size_t max_token_idx = 0;
11061107

1107-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1108-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1109-
// clip_l->compute(n_threads,
1110-
// input_ids,
1111-
// 0,
1112-
// NULL,
1113-
// max_token_idx,
1114-
// true,
1115-
// &pooled,
1116-
// work_ctx);
1117-
1118-
// clip_l.transformer.text_model.text_projection not in file, ignore
1119-
// TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
1120-
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1121-
ggml_set_f32(pooled, 0.f);
1108+
clip_l->compute(n_threads,
1109+
input_ids,
1110+
0,
1111+
NULL,
1112+
max_token_idx,
1113+
false,
1114+
&chunk_hidden_states_l,
1115+
work_ctx);
1116+
{
1117+
auto tensor = chunk_hidden_states_l;
1118+
float original_mean = ggml_tensor_mean(tensor);
1119+
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1120+
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1121+
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1122+
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
1123+
value *= chunk_weights[i1];
1124+
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
1125+
}
1126+
}
1127+
}
1128+
float new_mean = ggml_tensor_mean(tensor);
1129+
ggml_tensor_scale(tensor, (original_mean / new_mean));
1130+
}
1131+
if (chunk_idx == 0) {
1132+
size_t chunk_len_l = 77;
1133+
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1134+
clip_l_tokens.begin() + chunk_len_l);
1135+
std::vector<float> chunk_weights(clip_l_weights.begin(),
1136+
clip_l_weights.begin() + chunk_len_l);
1137+
1138+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1139+
size_t max_token_idx = 0;
1140+
1141+
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1142+
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1143+
// clip_l->compute(n_threads,
1144+
// input_ids,
1145+
// 0,
1146+
// NULL,
1147+
// max_token_idx,
1148+
// true,
1149+
// &pooled,
1150+
// work_ctx);
1151+
1152+
// clip_l.transformer.text_model.text_projection not in file, ignore
1153+
// TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
1154+
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1155+
ggml_set_f32(pooled, 0.f);
1156+
}
11221157
}
11231158

11241159
// t5
@@ -1132,10 +1167,10 @@ struct FluxCLIPEmbedder : public Conditioner {
11321167

11331168
t5->compute(n_threads,
11341169
input_ids,
1135-
&chunk_hidden_states,
1170+
&chunk_hidden_states_t5,
11361171
work_ctx);
11371172
{
1138-
auto tensor = chunk_hidden_states;
1173+
auto tensor = chunk_hidden_states_t5;
11391174
float original_mean = ggml_tensor_mean(tensor);
11401175
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
11411176
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
@@ -1151,6 +1186,29 @@ struct FluxCLIPEmbedder : public Conditioner {
11511186
}
11521187
}
11531188

1189+
1190+
// TODO: Maybe there's a better way to do the padding?
1191+
auto chunk_hidden_states_l_pad = ggml_new_tensor_3d(work_ctx,
1192+
chunk_hidden_states_l->type,
1193+
4096,
1194+
chunk_hidden_states_l->ne[1],
1195+
chunk_hidden_states_l->ne[2]); // [n_token, 4096]
1196+
1197+
for (int i2 = 0; i2 < chunk_hidden_states_l_pad->ne[2]; i2++) {
1198+
for (int i1 = 0; i1 < chunk_hidden_states_l_pad->ne[1]; i1++) {
1199+
for (int i0 = 0; i0 < chunk_hidden_states_l_pad->ne[0]; i0++) {
1200+
float value = 0.f;
1201+
if (i0 < chunk_hidden_states_l->ne[0]) {
1202+
value = ggml_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2);
1203+
}
1204+
ggml_tensor_set_f32(chunk_hidden_states_l_pad, value, i0, i1, i2);
1205+
}
1206+
}
1207+
}
1208+
1209+
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_l, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
1210+
1211+
11541212
int64_t t1 = ggml_time_ms();
11551213
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
11561214
if (force_zero_embeddings) {

0 commit comments

Comments
 (0)