@@ -27,27 +27,28 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
     }
   }

-  extern __shared__ int non_zero_counts[];
-  non_zero_counts[threadIdx.x] = count;
+  extern __shared__ int shared_row_indices[];
+  shared_row_indices[row + 1] = count;
   __syncthreads();

   // The first thread will calculate the accumulated partial sum of non-zero counts.
+  // The result is csr_row_indices stored in shared memory.
   if (row == 0) {
+    shared_row_indices[0] = 0;
     for (int i = 1; i < num_rows; i++) {
-      non_zero_counts[i] += non_zero_counts[i - 1];
+      shared_row_indices[i + 1] += shared_row_indices[i];
     }
+
+    // The first thread outputs the last element.
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
   }
   __syncthreads();

-  // The starting index of current row in csr_col_indices
-  int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
+  // The starting index of current row in csr_col_indices
+  int offset = shared_row_indices[row];

   // Output row indices.
   csr_row_indices[row] = offset;
-  if (row == 0) {
-    // The first thread output the last element.
-    csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
-  }

   for (int col = 0; col < num_cols; col++) {
     if (mask[row * num_cols + col] == 1) {
@@ -60,6 +61,59 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
   // The last element of csr_row_indices is the total number of non-zero elements.
 }

+__global__ void MaskToCSR_Large(const int* mask,
+                                int* csr_row_indices,
+                                int* csr_col_indices,
+                                int num_rows,
+                                int num_cols,
+                                int rows_per_thread  // Each thread handles multiple rows
+) {
+  extern __shared__ int shared_row_indices[];
+
+  // Update input and output data pointers to the start of current head
+  int head = blockIdx.x;
+  mask += head * num_rows * num_cols;
+  csr_row_indices += head * (num_rows + 1);
+  csr_col_indices += head * num_rows * num_cols;
+
+  int tid = threadIdx.x;
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int count = 0;
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        count++;
+      }
+    }
+    shared_row_indices[row + 1] = count;
+  }
+
+  __syncthreads();
+
+  // The first thread will calculate the accumulated partial sum of non-zero counts.
+  if (tid == 0) {
+    shared_row_indices[0] = 0;
+    for (int i = 1; i < num_rows; i++) {
+      shared_row_indices[i + 1] += shared_row_indices[i];
+    }
+
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
+  }
+
+  __syncthreads();
+
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int offset = shared_row_indices[row];
+    csr_row_indices[row] = offset;
+
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        csr_col_indices[offset] = col;
+        offset++;
+      }
+    }
+  }
+}
+
 void ConvertMaskToCSR(cudaStream_t stream,
                       const int* mask,  // input mask with shape (num_layout, num_rows, num_cols)
                       int num_layout,   // number of layouts
@@ -68,15 +122,17 @@ void ConvertMaskToCSR(cudaStream_t stream,
                       int* csr_row_indices,  // output CSR row indices
                       int* csr_col_indices,  // output CSR column indices
                       int max_threads_per_block) {
-  int threads_per_block = (num_rows + 31) / 32 * 32;
-
-  // Each thread handle one row. The kernel assumes that all rows of one head can be handled in one block.
-  if (threads_per_block > max_threads_per_block) {
-    ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
+  if (num_rows <= max_threads_per_block) {
+    // Each thread handles one row.
+    MaskToCSR<<<num_layout, num_rows, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
+  } else {
+    // Each thread handles multiple rows when the number of rows exceeds max_threads_per_block.
+    // For example, a 128K sequence length with sparse block size 64 gives 2048 rows, so each thread handles 2 rows.
+    int rows_per_thread = (num_rows + max_threads_per_block - 1) / max_threads_per_block;
+    MaskToCSR_Large<<<num_layout, max_threads_per_block, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols, rows_per_thread);
   }
-
-  MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
-      mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
 }

 }  // namespace cuda
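
For readers skimming the diff, the sketch below is a minimal host-side reference (not part of the commit) of what MaskToCSR and MaskToCSR_Large compute for a single head: CSR row and column indices built from a dense 0/1 mask. The function name MaskToCsrReference and the std::vector interface are illustrative assumptions, not ONNX Runtime APIs.

#include <vector>

// Hypothetical CPU reference for one head; mirrors the kernels' output layout.
void MaskToCsrReference(const std::vector<int>& mask,  // num_rows * num_cols, values 0 or 1
                        int num_rows,
                        int num_cols,
                        std::vector<int>& csr_row_indices,    // output, size num_rows + 1
                        std::vector<int>& csr_col_indices) {  // output, one entry per non-zero
  csr_row_indices.assign(num_rows + 1, 0);
  csr_col_indices.clear();
  for (int row = 0; row < num_rows; row++) {
    for (int col = 0; col < num_cols; col++) {
      if (mask[row * num_cols + col] == 1) {
        csr_col_indices.push_back(col);
      }
    }
    // Entry row + 1 holds the running total of non-zeros, i.e. the prefix sum the kernels
    // build in shared memory; entry num_rows ends up as the total non-zero count.
    csr_row_indices[row + 1] = static_cast<int>(csr_col_indices.size());
  }
}

The dispatch arithmetic in ConvertMaskToCSR matches the in-code comment: a 128K (131072) token sequence with sparse block size 64 gives num_rows = 131072 / 64 = 2048, and assuming the common limit of 1024 threads per block, rows_per_thread = ceil(2048 / 1024) = 2, so the MaskToCSR_Large path is taken.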