File tree Expand file tree Collapse file tree 1 file changed +13
-2
lines changed
Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -711,7 +711,18 @@ namespace Flux {
711711 }
712712
713713 void chroma_modify_mask_to_attend_padding (struct ggml_tensor * mask, int max_seq_length, int num_extra_padding = 8 ) {
714- // TODO: implement
714+ float * mask_data = (float *)mask->data ;
715+ int num_pad = 0 ;
716+ for (int64_t i = 0 ; i < max_seq_length; i++) {
717+ if (num_pad >= num_extra_padding) {
718+ break ;
719+ }
720+ if (isinf (mask_data[i])) {
721+ mask_data[i] = 0 ;
722+ ++num_pad;
723+ }
724+ }
725+ // LOG_DEBUG("PAD: %d", num_pad);
715726 }
716727
717728 // Generate positional embeddings
@@ -1073,7 +1084,7 @@ namespace Flux {
10731084 c_concat = to_backend (c_concat);
10741085 }
10751086 if (flux_params.is_chroma ) {
1076- flux.chroma_modify_mask_to_attend_padding (y, context-> ne [ 1 ] , 1 );
1087+ flux.chroma_modify_mask_to_attend_padding (y, ggml_nelements (y) , 1 );
10771088 // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
10781089 range = arange (0 , 344 );
10791090 precompute_arange = ggml_new_tensor_1d (compute_ctx, GGML_TYPE_F32, range.size ());
You can’t perform that action at this time.
0 commit comments