@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
9898 const void *block_tables, const void *seq_lens, const void *alibi_slopes,
9999 void *stream_) const {
100100 cudaStream_t stream = (cudaStream_t)stream_;
101+
// LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE): expands to one complete
// `launchKernel<__H_SIZE, __B_SIZE>(...);` statement (the trailing ';' is part
// of the expansion). It is not self-contained: it forwards `out`, `q`,
// `k_cache`, `v_cache`, `block_tables`, `seq_lens`, `alibi_slopes`, the
// `_info.*` fields and `stream` from whatever scope it is expanded in.
// The value produced by launchKernel, if any, is discarded by the expansion.
// NOTE(review): as rendered here there is a space between the macro name and
// '(' which would make this an object-like macro; presumably an extraction
// artifact (the same stray spacing appears before ',' elsewhere) - confirm
// against the real file.
102+ #define LAUNCH_HEADSIZE_BLOCKSIZE (__H_SIZE, __B_SIZE ) \
103+ launchKernel<__H_SIZE, __B_SIZE>( \
104+ out, q, k_cache, v_cache, _info.dtype , block_tables, seq_lens, alibi_slopes, \
105+ _info.num_heads , _info.num_seqs , \
106+ _info.num_kv_heads , _info.scale , _info.max_num_blocks_per_seq , _info.block_size , \
107+ _info.q_stride , _info.kv_block_stride , _info.kv_head_stride , _info.o_stride , \
108+ stream);
109+
// SWITCH_HEAD_SIZE(__B_SIZE): maps the runtime value `_info.head_size` onto a
// compile-time template argument. Each listed case (16, 32, 64, 128, 256)
// expands LAUNCH_HEADSIZE_BLOCKSIZE with that literal and the forwarded
// __B_SIZE, then breaks out of the switch.
// Any other head size takes the `default` branch, whose `return` statement
// exits the function in which the macro is expanded with
// INFINI_STATUS_BAD_TENSOR_SHAPE - so this macro is only usable inside a
// function returning that status type.
// NOTE(review): the space between the macro name and '(' looks like an
// extraction artifact - confirm against the real file.
110+ #define SWITCH_HEAD_SIZE (__B_SIZE ) \
111+ switch (_info.head_size ) { \
112+ case 16 : \
113+ LAUNCH_HEADSIZE_BLOCKSIZE (16 , __B_SIZE) \
114+ break ; \
115+ case 32 : \
116+ LAUNCH_HEADSIZE_BLOCKSIZE (32 , __B_SIZE) \
117+ break ; \
118+ case 64 : \
119+ LAUNCH_HEADSIZE_BLOCKSIZE (64 , __B_SIZE) \
120+ break ; \
121+ case 128 : \
122+ LAUNCH_HEADSIZE_BLOCKSIZE (128 , __B_SIZE) \
123+ break ; \
124+ case 256 : \
125+ LAUNCH_HEADSIZE_BLOCKSIZE (256 , __B_SIZE) \
126+ break ; \
127+ default : \
128+ return INFINI_STATUS_BAD_TENSOR_SHAPE; \
129+ }
130+
101131 if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_1024) {
102- if (_info.head_size == 128 ) {
103- launchKernel<128 , CUDA_BLOCK_SIZE_1024>(
104- out, q, k_cache, v_cache, _info.dtype , block_tables, seq_lens, alibi_slopes,
105- _info.num_heads , _info.num_seqs ,
106- _info.num_kv_heads , _info.scale , _info.max_num_blocks_per_seq , _info.block_size ,
107- _info.q_stride , _info.kv_block_stride , _info.kv_head_stride , _info.o_stride ,
108- stream);
109- }
132+ SWITCH_HEAD_SIZE (CUDA_BLOCK_SIZE_1024)
110133 } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_512) {
111- if (_info.head_size == 128 ) {
112- launchKernel<128 , CUDA_BLOCK_SIZE_512>(
113- out, q, k_cache, v_cache, _info.dtype , block_tables, seq_lens, alibi_slopes,
114- _info.num_heads , _info.num_seqs ,
115- _info.num_kv_heads , _info.scale , _info.max_num_blocks_per_seq , _info.block_size ,
116- _info.q_stride , _info.kv_block_stride , _info.kv_head_stride , _info.o_stride ,
117- stream);
118- }
134+ SWITCH_HEAD_SIZE (CUDA_BLOCK_SIZE_512)
119135 } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_4096) {
120- if (_info.head_size == 128 ) {
121- launchKernel<128 , CUDA_BLOCK_SIZE_4096>(
122- out, q, k_cache, v_cache, _info.dtype , block_tables, seq_lens, alibi_slopes,
123- _info.num_heads , _info.num_seqs ,
124- _info.num_kv_heads , _info.scale , _info.max_num_blocks_per_seq , _info.block_size ,
125- _info.q_stride , _info.kv_block_stride , _info.kv_head_stride , _info.o_stride ,
126- stream);
127- }
136+ SWITCH_HEAD_SIZE (CUDA_BLOCK_SIZE_4096)
128137 } else {
129138 return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
130139 }
131140
141+ #undef LAUNCH_HEADSIZE_BLOCKSIZE
142+ #undef SWITCH_HEAD_SIZE
143+
132144 return INFINI_STATUS_SUCCESS;
133145}
134146
0 commit comments