Skip to content

Commit 7de8880

Browse files
committed
revert padding change for sd chroma
1 parent 1cf7648 commit 7de8880

File tree

2 files changed

+27
-30
lines changed

2 files changed

+27
-30
lines changed

otherarch/sdcpp/conditioner.hpp

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,21 +1288,6 @@ struct PixArtCLIPEmbedder : public Conditioner {
12881288
return {t5_tokens, t5_weights, t5_mask};
12891289
}
12901290

1291-
void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
1292-
float* mask_data = (float*)mask->data;
1293-
int num_pad = 0;
1294-
for (int64_t i = 0; i < max_seq_length; i++) {
1295-
if (num_pad >= num_extra_padding) {
1296-
break;
1297-
}
1298-
if (std::isinf(mask_data[i])) {
1299-
mask_data[i] = 0;
1300-
++num_pad;
1301-
}
1302-
}
1303-
// LOG_DEBUG("PAD: %d", num_pad);
1304-
}
1305-
13061291
SDCondition get_learned_condition_common(ggml_context* work_ctx,
13071292
int n_threads,
13081293
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
@@ -1389,21 +1374,6 @@ struct PixArtCLIPEmbedder : public Conditioner {
13891374
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
13901375
ggml_set_f32(hidden_states, 0.f);
13911376
}
1392-
1393-
int mask_pad = 1;
1394-
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1395-
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1396-
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1397-
try {
1398-
mask_pad = std::stoi(mask_pad_str);
1399-
} catch (const std::invalid_argument&) {
1400-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1401-
} catch (const std::out_of_range&) {
1402-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1403-
}
1404-
}
1405-
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
1406-
14071377
return SDCondition(hidden_states, t5_attn_mask, NULL);
14081378
}
14091379

otherarch/sdcpp/flux.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,20 @@ namespace Flux {
709709
return ids;
710710
}
711711

712+
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
713+
float* mask_data = (float*)mask->data;
714+
int num_pad = 0;
715+
for (int64_t i = 0; i < max_seq_length; i++) {
716+
if (num_pad >= num_extra_padding) {
717+
break;
718+
}
719+
if (std::isinf(mask_data[i])) {
720+
mask_data[i] = 0;
721+
++num_pad;
722+
}
723+
}
724+
// LOG_DEBUG("PAD: %d", num_pad);
725+
}
712726

713727
// Generate positional embeddings
714728
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
@@ -1084,6 +1098,19 @@ namespace Flux {
10841098
guidance = ggml_set_f32(guidance, 0);
10851099
}
10861100

1101+
int mask_pad = 1;
1102+
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1103+
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1104+
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1105+
try {
1106+
mask_pad = std::stoi(mask_pad_str);
1107+
} catch (const std::invalid_argument&) {
1108+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1109+
} catch (const std::out_of_range&) {
1110+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1111+
}
1112+
}
1113+
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
10871114

10881115
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
10891116
if (SD_CHROMA_USE_DIT_MASK != nullptr) {

0 commit comments

Comments
 (0)