
Commit 3c3b588

misc: Add the keyword "template" to member template specialization (#1246)
Add the `template` keyword where a member template specialization appears after `.` or `->` in a postfix expression, as required by the C++ standard.

Signed-off-by: chenwei.sun <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent e462997 commit 3c3b588
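
For context, the C++ rule behind this change: when a member template of a dependent type is named after `.` or `->` with an explicit template argument list, the `template` keyword is required so that `<` is parsed as the start of the template arguments rather than as a less-than operator. A minimal standalone sketch of the rule (the `Smem` and `produce` names below are illustrative, not flashinfer APIs):

// Minimal sketch of the disambiguation rule; names are hypothetical.
struct Smem {
  template <int N>
  void load() {}
};

template <typename T>
void produce(T& smem) {
  // smem.load<8>();        // ill-formed: T is dependent, so "load" is parsed as a
  //                        // non-template name and "<" as less-than (GCC/Clang reject it)
  smem.template load<8>();  // OK: "template" marks load as a member template
}

int main() {
  Smem s;
  produce(s);  // instantiates produce<Smem>, which calls Smem::load<8>()
  return 0;
}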

3 files changed: +20 −20 lines

include/flashinfer/attention/hopper/prefill_sm90.cuh

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp
           shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
           num_kv_tiles_outside_items_window, num_kv_tiles_prefix);
     } else {
-      collective_mainloop.load<LEFT_SLIDING_WINDOW>(
+      collective_mainloop.template load<LEFT_SLIDING_WINDOW>(
           mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
           shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
     }

include/flashinfer/attention/prefill.cuh

Lines changed: 8 additions & 8 deletions

@@ -295,7 +295,7 @@ __device__ __forceinline__ void produce_kv(smem_t<KTraits::SWIZZLE_MODE_KV> smem
   for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
 #pragma unroll
     for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
-      smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
+      smem.template load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
       *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j);
       *gptr += 8 * upcast_size<DTypeKV>();
     }
@@ -434,7 +434,7 @@ __device__ __forceinline__ void load_q_global_smem(
   const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q<KTraits>(tid.y);

   if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
-    uint32_t q_smem_offset_w = q_smem->get_permuted_offset<UPCAST_STRIDE_Q>(
+    uint32_t q_smem_offset_w = q_smem->template get_permuted_offset<UPCAST_STRIDE_Q>(
         warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -449,8 +449,8 @@ __device__ __forceinline__ void load_q_global_smem(
 #pragma unroll
     for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
       // load q fragment from gmem to smem
-      q_smem->load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
-                                                          q_idx < qo_upper_bound);
+      q_smem->template load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
+                                                                   q_idx < qo_upper_bound);
       q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do);
       q_ptr += 8 * upcast_size<DTypeQ>();
     }
@@ -1258,12 +1258,12 @@ __device__ __forceinline__ void write_o_reg_gmem(
       vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);

 #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
           (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
           mma_d * 2 + lane_idx / 16);
       o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
 #else
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
           (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
       ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
       ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
@@ -1275,7 +1275,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
     }
   }

-  uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+  uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
       warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -1419,7 +1419,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice(
           ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
           : o + (kv_head_idx * group_size) * o_stride_h;

-  uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_STRIDE_Q>(
+  uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset<UPCAST_STRIDE_Q>(
       get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
   load_q_global_smem<KTraits>(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h,
                               group_size, &qo_smem, tid);
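
The hunks above also show the `->` spelling of the same rule: when the member template is reached through a pointer whose type depends on a template parameter, the disambiguator is written `->template`. A small sketch with hypothetical names (not flashinfer's smem_t):

// Minimal sketch of the "->template" form; names are hypothetical.
struct PermutedSmem {
  template <unsigned STRIDE>
  unsigned get_offset(unsigned row, unsigned col) const { return row * STRIDE + col; }
};

template <typename SmemT>
unsigned offset_of(const SmemT* smem, unsigned row, unsigned col) {
  // smem points to a dependent type, so the explicit <16> needs "->template".
  return smem->template get_offset<16>(row, col);
}

int main() {
  PermutedSmem s;
  return static_cast<int>(offset_of(&s, 0, 0));
}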

include/flashinfer/sampling.cuh

Lines changed: 11 additions & 11 deletions

@@ -301,7 +301,7 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
     }
     max_val = max(
         max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
+                     .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
     __syncthreads();
   }
   if (tx == 0) {
@@ -610,7 +610,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
   }
   float aggregate_local =
       BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
-          .Sum<VEC_SIZE>(prob_greater_than_threshold);
+          .template Sum<VEC_SIZE>(prob_greater_than_threshold);
   if (tx == 0) {
     temp_storage->block_aggregate.value = aggregate_local;
   }
@@ -623,7 +623,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
         prob_greater_than_threshold, inclusive_cdf, temp_storage);
   } else {
     BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
-        .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
+        .template InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);

     __syncthreads();
   }
@@ -639,7 +639,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
 #else
   BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
-      .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+      .template FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
 #endif
   __syncthreads();

@@ -775,7 +775,7 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*

     max_data +=
         BlockReduce<DataAndIndex<DType, IdType>, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage)
-            .Sum<VEC_SIZE>(cur_data);
+            .template Sum<VEC_SIZE>(cur_data);
   }
   if (tx == 0) {
     output[bx] = max_data.index;
@@ -1015,15 +1015,15 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     }

     aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                                .Sum<VEC_SIZE>(probs_gt_pivot_0);
+                                .template Sum<VEC_SIZE>(probs_gt_pivot_0);
     if (tx == 0) {
       temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
     }
     __syncthreads();
     aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;

     aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                                .Sum<VEC_SIZE>(probs_gt_pivot_1);
+                                .template Sum<VEC_SIZE>(probs_gt_pivot_1);
     if (tx == 0) {
       temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
     }
@@ -1676,12 +1676,12 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*

     aggregate_gt_pivot_0 +=
         BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-            .Sum<VEC_SIZE>(probs_gt_pivot_0);
+            .template Sum<VEC_SIZE>(probs_gt_pivot_0);
     __syncthreads();

     aggregate_gt_pivot_1 +=
         BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-            .Sum<VEC_SIZE>(probs_gt_pivot_1);
+            .template Sum<VEC_SIZE>(probs_gt_pivot_1);
     __syncthreads();
   }
   min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
@@ -1917,12 +1917,12 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*

     aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                 temp_storage.block_prim.reduce_value_count)
-                                .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
+                                .template Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
     __syncthreads();

     aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                 temp_storage.block_prim.reduce_value_count)
-                                .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
+                                .template Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
     __syncthreads();
   }
   min_gt_low =
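
The sampling.cuh hunks all involve CUB block-wide collectives whose type depends on the kernel's template parameters, so calling them with an explicit ITEMS_PER_THREAD argument hits the same rule. A minimal CUDA sketch under that assumption (the kernel below is illustrative, not a flashinfer kernel):

// Block-wide sum with an explicit per-thread item count; names are hypothetical.
#include <cub/block/block_reduce.cuh>

template <int BLOCK_THREADS, int VEC_SIZE>
__global__ void BlockSumKernel(const float* in, float* out) {
  // BlockReduce depends on BLOCK_THREADS, so it is a dependent type inside this template.
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread loads VEC_SIZE contiguous elements (assumes in has
  // gridDim.x * BLOCK_THREADS * VEC_SIZE entries).
  float thread_data[VEC_SIZE];
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    thread_data[i] = in[(blockIdx.x * BLOCK_THREADS + threadIdx.x) * VEC_SIZE + i];
  }

  // "template" is required: Sum is a dependent member template and the explicit
  // <VEC_SIZE> would otherwise be parsed as a comparison.
  float block_sum = BlockReduce(temp_storage).template Sum<VEC_SIZE>(thread_data);

  // Only thread 0 holds the valid block-wide result.
  if (threadIdx.x == 0) {
    out[blockIdx.x] = block_sum;
  }
}

// Example launch: BlockSumKernel<256, 4><<<num_blocks, 256>>>(d_in, d_out);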
