Skip to content

Commit bfb47cb

Browse files
committed
Revert "revert padding change for sd chroma"
This reverts commit 7de8880.
1 parent 5f9e96e commit bfb47cb

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

otherarch/sdcpp/conditioner.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,21 @@ struct PixArtCLIPEmbedder : public Conditioner {
12881288
return {t5_tokens, t5_weights, t5_mask};
12891289
}
12901290

1291+
void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
1292+
float* mask_data = (float*)mask->data;
1293+
int num_pad = 0;
1294+
for (int64_t i = 0; i < max_seq_length; i++) {
1295+
if (num_pad >= num_extra_padding) {
1296+
break;
1297+
}
1298+
if (std::isinf(mask_data[i])) {
1299+
mask_data[i] = 0;
1300+
++num_pad;
1301+
}
1302+
}
1303+
// LOG_DEBUG("PAD: %d", num_pad);
1304+
}
1305+
12911306
SDCondition get_learned_condition_common(ggml_context* work_ctx,
12921307
int n_threads,
12931308
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
@@ -1374,6 +1389,21 @@ struct PixArtCLIPEmbedder : public Conditioner {
13741389
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
13751390
ggml_set_f32(hidden_states, 0.f);
13761391
}
1392+
1393+
int mask_pad = 1;
1394+
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1395+
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1396+
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1397+
try {
1398+
mask_pad = std::stoi(mask_pad_str);
1399+
} catch (const std::invalid_argument&) {
1400+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1401+
} catch (const std::out_of_range&) {
1402+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1403+
}
1404+
}
1405+
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
1406+
13771407
return SDCondition(hidden_states, t5_attn_mask, NULL);
13781408
}
13791409

otherarch/sdcpp/flux.hpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -709,20 +709,6 @@ namespace Flux {
709709
return ids;
710710
}
711711

712-
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
713-
float* mask_data = (float*)mask->data;
714-
int num_pad = 0;
715-
for (int64_t i = 0; i < max_seq_length; i++) {
716-
if (num_pad >= num_extra_padding) {
717-
break;
718-
}
719-
if (std::isinf(mask_data[i])) {
720-
mask_data[i] = 0;
721-
++num_pad;
722-
}
723-
}
724-
// LOG_DEBUG("PAD: %d", num_pad);
725-
}
726712

727713
// Generate positional embeddings
728714
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
@@ -1098,19 +1084,6 @@ namespace Flux {
10981084
guidance = ggml_set_f32(guidance, 0);
10991085
}
11001086

1101-
int mask_pad = 1;
1102-
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1103-
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1104-
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1105-
try {
1106-
mask_pad = std::stoi(mask_pad_str);
1107-
} catch (const std::invalid_argument&) {
1108-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1109-
} catch (const std::out_of_range&) {
1110-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1111-
}
1112-
}
1113-
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
11141087

11151088
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
11161089
if (SD_CHROMA_USE_DIT_MASK != nullptr) {

0 commit comments

Comments
 (0)