#include <numeric>
#include <random>
#include <sstream>
+#include <type_traits>

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
    int64_t const arr_length, T const target) {
  int64_t low = 0, high = arr_length - 1, target_location = -1;
  while (low <= high) {
@@ -881,6 +882,49 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
  return target_location + 1;
}

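+// Warp-cooperative replacement for the binary search above: each of the 32 lanes counts
+// the entries it owns that are < target, then a shuffle reduction produces the warp-wide
+// total. Assumes arr_length is exactly 128 (enforced by the trap below) and that all 32
+// lanes of the warp reach this call together with the same target.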
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  constexpr int ARR_LENGTH_CONST = 128;
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }
+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+  #pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
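+  // Tree-reduce the per-lane counts into lane 0, then broadcast lane 0's total so every
+  // lane returns the same value.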
+  #pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
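+// Thin dispatcher so call sites keep using findTotalEltsLessThanTarget. It currently
+// forwards to v2; the commented-out block cross-checks v2 against the original binary
+// search (v1) and traps on any mismatch.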
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}
+

template <class T>
using sizeof_bits = cutlass::sizeof_bits<
    typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@@ -1414,20 +1458,23 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
// source matrix, we simply take the modulus of the expanded index.

-constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
+constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
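+// Halved from 256; the row-expansion launcher below now also launches up to 16 blocks
+// per SM instead of 8.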

template <class InputActivationsType, class ExpandedActivationsType,
    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
    bool PRE_QUANT_AWQ>
__global__ void expandInputRowsKernel(
    InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
    float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
    int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
    int64_t const* expert_first_token_offset,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
    int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
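+  // Specialize the kernel to a fixed hidden size (7168 here) so stride and
+  // num_elems_in_col below become compile-time constants and the copy/quantize loop can
+  // be unrolled; any other runtime value traps immediately.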
+  constexpr int hidden_size = 7168;
+  if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
  static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                    !PRE_QUANT_AWQ,
                "AWQ and Block Scaling are mutually exclusive");
@@ -1503,8 +1550,8 @@ __global__ void expandInputRowsKernel(
        permuted_row * hidden_size / ELEM_PER_THREAD;

    int64_t const start_offset = threadIdx.x;
-    int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-    int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+    constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+    constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
    assert(hidden_size % ELEM_PER_THREAD == 0);
    assert(hidden_size % VecSize == 0);

@@ -1519,6 +1566,7 @@ __global__ void expandInputRowsKernel(
    float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
    int64_t num_tokens_before_expert = expert_first_token_offset[expert];

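+    // stride and num_elems_in_col are compile-time constants now, so the unroll pragma
+    // below can actually unroll this copy/quantize loop.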
+    #pragma unroll
    for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
      auto in_vec = source_row_ptr[elem_index];
      if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1650,7 +1698,7 @@ void expandInputRowsKernelLauncher(

  static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
  // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-  int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
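+  // Doubled to 16 blocks per SM; the measurement note above presumably refers to the
+  // previous 8-blocks-per-SM, 256-thread configuration.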
+  int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens));
  int64_t const threads = EXPAND_THREADS_PER_BLOCK;

  auto func = [&]() {
@@ -1748,11 +1796,20 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
// connection.
template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
-__global__ void finalizeMoeRoutingKernel(
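+// __maxnreg__(64) caps the kernel at 64 registers per thread, presumably to preserve
+// occupancy now that each thread stages up to experts_per_token int4 loads in registers.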
+__global__
+__maxnreg__(64)
+void finalizeMoeRoutingKernel(
    GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
    int const num_experts_per_node, int const start_expert_id) {
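+  // Specialize to the bf16-in / bf16-out instantiation with a compile-time top-k of 8;
+  // other dtype combinations or experts_per_token values trap at runtime. Keeping the
+  // body in the else branch of if constexpr means it is only instantiated for the
+  // supported bf16 case.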
+  if constexpr (not(std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
+    printf("finalizeMoeRoutingKernel sees unsupported dtype\n");
+    asm("trap;");
+  } else {
+    constexpr int experts_per_token = 8;
+    if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
  int64_t const original_row = blockIdx.x;
  int64_t const num_rows = gridDim.x;
  auto const offset = original_row * orig_cols;
@@ -1784,43 +1841,67 @@ __global__ void finalizeMoeRoutingKernel(
  for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
    ComputeElem thread_output;
    thread_output.fill(0);
+
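+    // First pass: issue the 16-byte load for every valid expert slot of this column
+    // back-to-back into registers, recording which slots are valid in a bitmask, so the
+    // global loads can be in flight together instead of interleaved with the accumulation.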
+    int4 input_val_buf[experts_per_token];
+    uint32_t enable_input_buf = 0;
+
+    #pragma unroll
    for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
      int64_t const k_offset = original_row * experts_per_token + k_idx;
      int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
-      if (expert_id < 0 || expert_id >= num_experts_per_node) {
-        continue;
-      }

      int64_t const expanded_original_row = original_row + k_idx * num_rows;
      int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];

+      if (expert_id < 0 || expert_id >= num_experts_per_node) {
+        continue;
+      }
+
      int64_t expanded_rows = num_rows * experts_per_token;
      if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
        continue;
      }

-      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
-
      auto const* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

-      ComputeElem expert_result =
-          arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
-      if (bias) {
-        auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
-        expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
-      }
+      // ComputeElem expert_result =
+      //     arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
+      static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4));
+      input_val_buf[k_idx] = *reinterpret_cast<const int4*>(expanded_permuted_rows_row_ptr + elem_index);
+      enable_input_buf |= 1 << k_idx;
+    }
+
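+    // Second pass: convert, scale, and accumulate the buffered values. Note that the
+    // bias path is left disabled (commented out) in this specialized variant.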
+    #pragma unroll
+    for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
+      if (not(enable_input_buf & (1 << k_idx))) continue;
+
+      int64_t const k_offset = original_row * experts_per_token + k_idx;
+      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
+
+      int4 input_val = input_val_buf[k_idx];
+      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(*reinterpret_cast<const InputElem*>(&input_val));
+      // if (bias) {
+      //   auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
+      //   expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
+      // }

      thread_output = thread_output + row_scale * expert_result;
    }

-    OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
-    reduced_row_ptr_v[elem_index] = output_elem;
+    // OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
+    // reduced_row_ptr_v[elem_index] = output_elem;
+    // TODO: alignment issue?
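+    // Write the reduced row element with a single 16-byte (int4) store rather than an
+    // element-wise assignment; the static_assert below checks that the output element
+    // type is int4-sized.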
+    __align__(16) OutputElem output_elem_original = arrayConvert<ComputeElem, OutputElem>(thread_output);
+    int4 output_elem = *reinterpret_cast<int4*>(&output_elem_original);
+    static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4));
+    *reinterpret_cast<int4*>(reduced_row_ptr_v + elem_index) = output_elem;
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
+}

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
@@ -2130,10 +2211,13 @@ template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
    float const* fp8_quant, ScaleBiasType const* bias_ptr,
    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-    int num_experts_per_node, int64_t inter_size,
+    int num_experts_per_node, int64_t inter_size_real_,
    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
    ActivationParams activation_params) {
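+  // Same specialization trick as expandInputRowsKernel: fix inter_size at compile time
+  // so stride and num_elems_in_col become constexpr and the activation loop can be
+  // unrolled; trap if the runtime value differs.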
+  constexpr int inter_size = 2048;
+  if (inter_size != inter_size_real_) { asm("trap;"); }
+
#ifdef ENABLE_FP4
  constexpr bool IsNVFP4 =
      std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2218,16 +2302,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
  auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
  auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
  int64_t const start_offset = tid;
-  int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+  constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
  assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-  int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+  constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
  assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
  int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;

  ActFn fn{};
  fn.alpha = gate_alpha;
  fn.beta = gate_beta;
  fn.limit = gate_limit;
+
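+  // With inter_size fixed above, num_elems_in_col is a compile-time constant and the
+  // unroll pragma can take effect.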
+  #pragma unroll
  for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
    auto fc1_value =
        arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);