Commit 836fd72

Chroma: Attention masking (no pad)
1 parent f506a63 commit 836fd72

File tree

4 files changed (+266, -38 lines)


conditioner.hpp

Lines changed: 180 additions & 2 deletions
@@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner {
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
         clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1077,7 +1077,7 @@ struct FluxCLIPEmbedder : public Conditioner {
         }
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
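Both call sites above gain an extra NULL before max_length, and the new PixArtCLIPEmbedder below passes &t5_mask in the same position, so pad_tokens evidently acquired an optional attention-mask output parameter. A minimal sketch of the shape this implies, written as a hypothetical free function (pad_id and the additive 0/-inf mask convention are assumptions, not read from this diff):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Hedged sketch only: mirrors the signature assumed at the call sites above,
    // not the actual T5UniGramTokenizer implementation in this commit.
    void pad_tokens_sketch(std::vector<int>& tokens,
                           std::vector<float>& weights,
                           std::vector<float>* attention_mask,  // NULL => caller wants no mask
                           int pad_id,                          // hypothetical: tokenizer's pad token id
                           size_t max_length,
                           bool padding) {
        size_t real_len = std::min(tokens.size(), max_length);
        if (padding && max_length > 0) {
            tokens.resize(max_length, pad_id);
            weights.resize(max_length, 1.0f);
        }
        if (attention_mask != NULL) {
            // assumed convention: additive attention bias, 0 over real tokens,
            // -inf over padding (a 1/0 multiplicative mask is equally plausible)
            attention_mask->assign(tokens.size(), -HUGE_VALF);
            std::fill(attention_mask->begin(), attention_mask->begin() + real_len, 0.0f);
        }
    }

Passing NULL, as the SD3 and Flux embedders do above, keeps the old no-mask behavior; the PixArtCLIPEmbedder below passes &t5_mask to collect the mask alongside tokens and weights.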
@@ -1218,4 +1218,182 @@ struct FluxCLIPEmbedder : public Conditioner {
     }
 };
 
+struct PixArtCLIPEmbedder : public Conditioner {
+    T5UniGramTokenizer t5_tokenizer;
+    std::shared_ptr<T5Runner> t5;
+
+    PixArtCLIPEmbedder(ggml_backend_t backend,
+                       std::map<std::string, enum ggml_type>& tensor_types,
+                       int clip_skip = -1) {
+        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+    }
+
+    void set_clip_skip(int clip_skip) {
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+    }
+
+    void alloc_params_buffer() {
+        t5->alloc_params_buffer();
+    }
+
+    void free_params_buffer() {
+        t5->free_params_buffer();
+    }
+
+    size_t get_params_buffer_size() {
+        size_t buffer_size = 0;
+
+        buffer_size += t5->get_params_buffer_size();
+
+        return buffer_size;
+    }
+
+    std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
+                                                                                  size_t max_length = 0,
+                                                                                  bool padding      = false) {
+        auto parsed_attention = parse_prompt_attention(text);
+
+        {
+            std::stringstream ss;
+            ss << "[";
+            for (const auto& item : parsed_attention) {
+                ss << "['" << item.first << "', " << item.second << "], ";
+            }
+            ss << "]";
+            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+        }
+
+        auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
+            return false;
+        };
+
+        std::vector<int> t5_tokens;
+        std::vector<float> t5_weights;
+        std::vector<float> t5_mask;
+        for (const auto& item : parsed_attention) {
+            const std::string& curr_text = item.first;
+            float curr_weight            = item.second;
+
+            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+        }
+
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+
+        return {t5_tokens, t5_weights, t5_mask};
+    }
+
+    SDCondition get_learned_condition_common(ggml_context* work_ctx,
+                                             int n_threads,
+                                             std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
+                                             int clip_skip,
+                                             bool force_zero_embeddings = false) {
+        auto& t5_tokens        = std::get<0>(token_and_weights);
+        auto& t5_weights       = std::get<1>(token_and_weights);
+        auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
+
+        int64_t t0                              = ggml_time_ms();
+        struct ggml_tensor* hidden_states       = NULL;  // [N, n_token, 4096]
+        struct ggml_tensor* chunk_hidden_states = NULL;  // [n_token, 4096]
+        struct ggml_tensor* pooled              = NULL;  // [768,]
+        struct ggml_tensor* t5_attn_mask        = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec);  // [n_token,]
+
+        std::vector<float> hidden_states_vec;
+
+        size_t chunk_len   = 256;
+        size_t chunk_count = t5_tokens.size() / chunk_len;
+        for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+            // t5
+            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+                                             t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
+                                          t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
+
+            auto input_ids          = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+            auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
+
+            t5->compute(n_threads,
+                        input_ids,
+                        &chunk_hidden_states,
+                        work_ctx,
+                        t5_attn_mask_chunk);
+            {
+                auto tensor         = chunk_hidden_states;
+                float original_mean = ggml_tensor_mean(tensor);
+                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                            float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+                            value *= chunk_weights[i1];
+                            ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+                        }
+                    }
+                }
+                float new_mean = ggml_tensor_mean(tensor);
+                ggml_tensor_scale(tensor, (original_mean / new_mean));
+            }
+
+            int64_t t1 = ggml_time_ms();
+            LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+            if (force_zero_embeddings) {
+                float* vec = (float*)chunk_hidden_states->data;
+                for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
+                    vec[i] = 0;
+                }
+            }
+
+            hidden_states_vec.insert(hidden_states_vec.end(),
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+        }
+
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            ggml_set_f32(hidden_states, 0.f);
+        }
+        return SDCondition(hidden_states, t5_attn_mask, NULL);
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const std::string& text,
+                                      int clip_skip,
+                                      int width,
+                                      int height,
+                                      int adm_in_channels        = -1,
+                                      bool force_zero_embeddings = false) {
+        auto tokens_and_weights = tokenize(text, 512, true);
+        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
+    }
+
+    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
+                                                                                  int n_threads,
+                                                                                  const std::string& text,
+                                                                                  int clip_skip,
+                                                                                  int width,
+                                                                                  int height,
+                                                                                  int num_input_imgs,
+                                                                                  int adm_in_channels        = -1,
+                                                                                  bool force_zero_embeddings = false) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+
+    std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+                                           const std::string& prompt) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+};
+
 #endif
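Inside the chunk loop, the weighting block applies the parsed per-token prompt weights and then rescales so the overall mean of the hidden states is unchanged (original_mean / new_mean). Stripped of the ggml accessors, the idea reduces to the following paraphrase over a flat [n_token, d] buffer (apply_prompt_weights is a name invented here, not a drop-in replacement for the loop above):

    #include <cstddef>
    #include <vector>

    // Mean-preserving prompt weighting: scale each token's embedding by its
    // weight, then rescale globally so the tensor mean matches the original.
    void apply_prompt_weights(std::vector<float>& hidden, size_t n_token, size_t d,
                              const std::vector<float>& weights) {
        double sum_before = 0.0;
        for (float v : hidden) sum_before += v;

        for (size_t t = 0; t < n_token; t++)
            for (size_t i = 0; i < d; i++)
                hidden[t * d + i] *= weights[t];

        double sum_after = 0.0;
        for (float v : hidden) sum_after += v;

        // the element count is unchanged, so the mean ratio equals the sum ratio
        if (sum_after != 0.0) {
            float scale = (float)(sum_before / sum_after);
            for (float& v : hidden) v *= scale;
        }
    }

Same caveat as the diff itself: if the weighted sum lands near zero, the rescale factor blows up; the committed code divides by new_mean unconditionally.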

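One consequence of the chunking arithmetic in get_learned_condition_common: chunk_count uses integer division, so tokens beyond a multiple of chunk_len are dropped, and an input shorter than one chunk yields zero chunks and falls through to the zeroed fallback tensor (256 tokens x 4096 dims). With the padded tokenize(text, 512, true) call in get_learned_condition this works out to exactly two chunks:

    // chunk arithmetic as written in the diff above:
    size_t chunk_len   = 256;
    size_t n_tokens    = 512;                  // tokenize(text, 512, true) pads to 512
    size_t chunk_count = n_tokens / chunk_len; // == 2; e.g. 255 tokens would give 0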

Comments
 (0)