Commit 67cc996
implement chroma mask padding

Parent: 836fd72

1 file changed

flux.hpp: 13 additions, 2 deletions
@@ -711,7 +711,18 @@ namespace Flux {
     }

     void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
-        // TODO: implement
+        float* mask_data = (float*)mask->data;
+        int num_pad = 0;
+        for (int64_t i = 0; i < max_seq_length; i++) {
+            if (num_pad >= num_extra_padding) {
+                break;
+            }
+            if (isinf(mask_data[i])) {
+                mask_data[i] = 0;
+                ++num_pad;
+            }
+        }
+        // LOG_DEBUG("PAD: %d", num_pad);
     }

     // Generate positional embeddings
@@ -1073,7 +1084,7 @@ namespace Flux {
         c_concat = to_backend(c_concat);
     }
     if (flux_params.is_chroma) {
-        flux.chroma_modify_mask_to_attend_padding(y, context->ne[1], 1);
+        flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1);
         // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
         range = arange(0, 344);
         precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());
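The call-site change widens the scan bound: context->ne[1] is the context tensor's second dimension (its token count under ggml's innermost-first layout), while ggml_nelements(y) is the total element count of the mask tensor y, so the padding scan can now reach every mask entry. A hypothetical demonstration of the two bounds, with small dummy tensors standing in for the real context and y:

    #include <cstdio>
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*mem_size=*/16 * 1024 * 1024,
            /*mem_buffer=*/NULL,
            /*no_alloc=*/false,
        };
        struct ggml_context* ctx = ggml_init(params);

        // stand-ins: a (hidden=8, tokens=5) context and a 10-element mask y
        struct ggml_tensor* context = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 5);
        struct ggml_tensor* y       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);

        std::printf("old bound (context->ne[1]):    %lld\n", (long long)context->ne[1]);     // 5
        std::printf("new bound (ggml_nelements(y)): %lld\n", (long long)ggml_nelements(y));  // 10

        ggml_free(ctx);
        return 0;
    }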
