Commit d18cdac

Adds head count fields to mask and bias parameters
Introduces the h_mask and h_bias fields, which record the number of heads in the attention mask and the attention bias, respectively. Storing the head counts alongside the existing strides lets flash attention operations check and index the mask and bias head dimensions directly.
1 parent 1cf1385 commit d18cdac

csrc/flash_dmattn/src/flash.h

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,9 @@ struct Mask_params {
     index_t mask_batch_stride;  // Stride between batches of attention mask
     index_t mask_head_stride;   // Stride between heads of attention mask
     index_t mask_row_stride;    // Stride between rows of attention mask
+
+    // The number of heads in the mask.
+    int h_mask;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -61,6 +64,9 @@ struct Bias_params {
     index_t bias_batch_stride;  // Stride between batches of attention bias
     index_t bias_head_stride;   // Stride between heads of attention bias
     index_t bias_row_stride;    // Stride between rows of attention bias
+
+    // The number of heads in the bias.
+    int h_bias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
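
This commit only adds the fields, so the following is a sketch of how a stored head count could be consumed, not code from the repository. It assumes index_t is a 64-bit integer and that a mask with fewer heads than the query is shared across equal-sized groups of query heads. The helpers head_count_is_compatible, mask_head_index, and mask_head_offset are hypothetical names introduced for illustration. The same pattern would apply to Bias_params and h_bias.

```cpp
#include <cassert>
#include <cstdint>

using index_t = int64_t;  // Assumption: 64-bit index type, not taken from this diff.

// Trimmed copy of the struct from the diff; only the fields used below.
struct Mask_params {
    index_t mask_batch_stride;  // Stride between batches of attention mask
    index_t mask_head_stride;   // Stride between heads of attention mask
    index_t mask_row_stride;    // Stride between rows of attention mask
    int h_mask;                 // The number of heads in the mask
};

// Hypothetical check: the mask head count must evenly divide the query head count.
inline bool head_count_is_compatible(int h, int h_mask) {
    return h > 0 && h_mask > 0 && h % h_mask == 0;
}

// Hypothetical mapping from a query head to the mask head it reads.
// h_mask == 1 broadcasts one mask to every head; h_mask == h is one-to-one.
inline int mask_head_index(int q_head, int h, int h_mask) {
    assert(head_count_is_compatible(h, h_mask));
    assert(q_head >= 0 && q_head < h);
    return q_head / (h / h_mask);
}

// Hypothetical element offset of the mask slice for (batch, q_head).
inline index_t mask_head_offset(const Mask_params &p, int batch, int q_head, int h) {
    return batch * p.mask_batch_stride
         + mask_head_index(q_head, h, p.h_mask) * p.mask_head_stride;
}
```

With 8 query heads and h_mask = 2, query heads 0 to 3 would read mask head 0 and heads 4 to 7 would read mask head 1. Keeping the head count in the params struct means a check like head_count_is_compatible can run once on the host before launch rather than being inferred from strides.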