@@ -1550,35 +1550,38 @@ at::Tensor rope_qkv_varseq_prefill(
1550
1550
static_cast <int64_t *>(actual_batch_size.value ().data_ptr ());
1551
1551
}
1552
1552
if (cache_K.dtype () == at::kBFloat16 ) {
1553
- rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
1554
- <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1555
- XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1556
- XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1557
- XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1558
- cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1559
- cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1560
- XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1561
- varseq_batch.data_ptr <int32_t >(),
1562
- varseq_seqpos
1563
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1564
- theta,
1565
- 0 ,
1566
- 0 ,
1567
- 0 ,
1568
- block_tables_ptr,
1569
- page_size,
1570
- block_tables_b_stride,
1571
- varseq_cache_seqpos_
1572
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1573
- actual_batch_size_ptr,
1574
- rope_scaling,
1575
- old_context_len,
1576
- scaling_factor,
1577
- lo_freq_factor,
1578
- hi_freq_factor,
1579
- write_k_back,
1580
- update_kv);
1581
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1553
+ FBGEMM_LAUNCH_KERNEL (
1554
+ (rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>),
1555
+ blocks,
1556
+ threads,
1557
+ 0 ,
1558
+ at::cuda::getCurrentCUDAStream (),
1559
+ XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1560
+ XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1561
+ XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1562
+ cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1563
+ cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1564
+ XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1565
+ varseq_batch.data_ptr <int32_t >(),
1566
+ varseq_seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1567
+ theta,
1568
+ 0 ,
1569
+ 0 ,
1570
+ 0 ,
1571
+ block_tables_ptr,
1572
+ page_size,
1573
+ block_tables_b_stride,
1574
+ varseq_cache_seqpos_
1575
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1576
+ actual_batch_size_ptr,
1577
+ rope_scaling,
1578
+ old_context_len,
1579
+ scaling_factor,
1580
+ lo_freq_factor,
1581
+ hi_freq_factor,
1582
+ write_k_back,
1583
+ update_kv);
1584
+
1582
1585
} else {
1583
1586
auto num_groups_ = num_groups ? num_groups.value () : 1 ;
1584
1587
auto varseq_batch_ = varseq_batch.data_ptr <int32_t >();
@@ -1767,33 +1770,38 @@ at::Tensor xpos_qkv_varseq_prefill(
1767
1770
static_cast <int64_t *>(actual_batch_size.value ().data_ptr ());
1768
1771
}
1769
1772
if (cache_K.dtype () == at::kBFloat16 ) {
1770
- rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
1771
- <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1772
- XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1773
- XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1774
- XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1775
- cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1776
- cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1777
- XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1778
- varseq_batch.data_ptr <int32_t >(),
1779
- varseq_seqpos
1780
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1781
- theta,
1782
- gamma,
1783
- scale_base,
1784
- exponent_offset,
1785
- block_tables_ptr,
1786
- page_size,
1787
- block_tables_b_stride,
1788
- varseq_cache_seqpos_
1789
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1790
- actual_batch_size_ptr,
1791
- rope_scaling,
1792
- old_context_len,
1793
- scaling_factor,
1794
- lo_freq_factor,
1795
- hi_freq_factor);
1796
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1773
+ FBGEMM_LAUNCH_KERNEL (
1774
+ (rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>),
1775
+ blocks,
1776
+ threads,
1777
+ 0 ,
1778
+ at::cuda::getCurrentCUDAStream (),
1779
+ XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1780
+ XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1781
+ XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1782
+ cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1783
+ cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1784
+ XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1785
+ varseq_batch.data_ptr <int32_t >(),
1786
+ varseq_seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1787
+ theta,
1788
+ gamma,
1789
+ scale_base,
1790
+ exponent_offset,
1791
+ block_tables_ptr,
1792
+ page_size,
1793
+ block_tables_b_stride,
1794
+ varseq_cache_seqpos_
1795
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1796
+ actual_batch_size_ptr,
1797
+ rope_scaling,
1798
+ old_context_len,
1799
+ scaling_factor,
1800
+ lo_freq_factor,
1801
+ hi_freq_factor,
1802
+ false ,
1803
+ true );
1804
+
1797
1805
} else {
1798
1806
auto num_groups_ = num_groups ? num_groups.value () : 1 ;
1799
1807
auto varseq_batch_ = varseq_batch.data_ptr <int32_t >();
@@ -1935,34 +1943,37 @@ at::Tensor rope_qkv_decoding(
1935
1943
}
1936
1944
auto cache_seqpos_ = cache_seqpos.value_or (seqpos);
1937
1945
if (cache_K.dtype () == at::kBFloat16 ) {
1938
- rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
1939
- <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1940
- XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1941
- XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1942
- XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1943
- cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1944
- cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1945
- XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1946
- batch.has_value () ? batch.value ().data_ptr <int32_t >() : nullptr ,
1947
- seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1948
- theta,
1949
- 0 ,
1950
- 0 ,
1951
- 0 ,
1952
- block_tables_ptr,
1953
- page_size,
1954
- block_tables_b_stride,
1955
- cache_seqpos_
1956
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1957
- actual_batch_size_ptr,
1958
- rope_scaling,
1959
- old_context_len,
1960
- scaling_factor,
1961
- lo_freq_factor,
1962
- hi_freq_factor,
1963
- false ,
1964
- update_kv);
1965
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1946
+ FBGEMM_LAUNCH_KERNEL (
1947
+ (rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>),
1948
+ blocks,
1949
+ threads,
1950
+ 0 ,
1951
+ at::cuda::getCurrentCUDAStream (),
1952
+ XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1953
+ XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1954
+ XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1955
+ cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1956
+ cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
1957
+ XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1958
+ batch.has_value () ? batch.value ().data_ptr <int32_t >() : nullptr ,
1959
+ seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1960
+ theta,
1961
+ 0 ,
1962
+ 0 ,
1963
+ 0 ,
1964
+ block_tables_ptr,
1965
+ page_size,
1966
+ block_tables_b_stride,
1967
+ cache_seqpos_.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
1968
+ actual_batch_size_ptr,
1969
+ rope_scaling,
1970
+ old_context_len,
1971
+ scaling_factor,
1972
+ lo_freq_factor,
1973
+ hi_freq_factor,
1974
+ false ,
1975
+ update_kv);
1976
+
1966
1977
} else {
1967
1978
auto seqpos_ =
1968
1979
seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>();
@@ -2142,32 +2153,37 @@ at::Tensor xpos_qkv_decoding(
2142
2153
}
2143
2154
auto cache_seqpos_ = cache_seqpos.value_or (seqpos);
2144
2155
if (cache_K.dtype () == at::kBFloat16 ) {
2145
- rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
2146
- <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
2147
- XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2148
- XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2149
- XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2150
- cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
2151
- cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
2152
- XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2153
- batch.has_value () ? batch.value ().data_ptr <int32_t >() : nullptr ,
2154
- seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
2155
- theta,
2156
- gamma,
2157
- scale_base,
2158
- exponent_offset,
2159
- block_tables_ptr,
2160
- page_size,
2161
- block_tables_b_stride,
2162
- cache_seqpos_
2163
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
2164
- actual_batch_size_ptr,
2165
- rope_scaling,
2166
- old_context_len,
2167
- scaling_factor,
2168
- lo_freq_factor,
2169
- hi_freq_factor);
2170
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
2156
+ FBGEMM_LAUNCH_KERNEL (
2157
+ (rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>),
2158
+ blocks,
2159
+ threads,
2160
+ 0 ,
2161
+ at::cuda::getCurrentCUDAStream (),
2162
+ XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2163
+ XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2164
+ XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2165
+ cache_K.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
2166
+ cache_V.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(),
2167
+ XQ_O.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
2168
+ batch.has_value () ? batch.value ().data_ptr <int32_t >() : nullptr ,
2169
+ seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
2170
+ theta,
2171
+ gamma,
2172
+ scale_base,
2173
+ exponent_offset,
2174
+ block_tables_ptr,
2175
+ page_size,
2176
+ block_tables_b_stride,
2177
+ cache_seqpos_.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
2178
+ actual_batch_size_ptr,
2179
+ rope_scaling,
2180
+ old_context_len,
2181
+ scaling_factor,
2182
+ lo_freq_factor,
2183
+ hi_freq_factor,
2184
+ false ,
2185
+ true );
2186
+
2171
2187
} else {
2172
2188
auto num_groups_ = num_groups ? num_groups.value () : 1 ;
2173
2189
auto seqpos_ =
0 commit comments