#include <numeric>
#include <random>
#include <sstream>
- #include <type_traits>

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/workspace.h"
@@ -866,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken(
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
- __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
+ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
    int64_t const arr_length, T const target) {
    int64_t low = 0, high = arr_length - 1, target_location = -1;
    while (low <= high) {
@@ -882,49 +881,6 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
    return target_location + 1;
}

- template <class T>
- __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
-     constexpr int ARR_LENGTH_CONST = 128;
-     if (arr_length != ARR_LENGTH_CONST) {
-         asm("trap;");
-     }
-
-     constexpr unsigned full_mask = 0xffffffffu;
-     constexpr int WARP_SZ = 32;
-     const int lane_id = threadIdx.x & (WARP_SZ - 1);
-
-     int local_count = 0;
-     #pragma unroll
-     for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
-         const int idx = lane_id + k * WARP_SZ;
-         T v = sorted_indices[idx];
-         local_count += (v < target) ? 1 : 0;
-     }
-
-     #pragma unroll
-     for (int offset = 16; offset > 0; offset >>= 1) {
-         local_count += __shfl_down_sync(full_mask, local_count, offset);
-     }
-     int total = __shfl_sync(full_mask, local_count, 0);
-
-     return (int64_t) total;
- }
-
- template <class T>
- __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
-     // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
-
-     return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
-
-     // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
-     // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
-     // if (out_v1 != out_v2) {
-     //     printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
-     //     asm("trap;");
-     // }
-     // return out_v1;
- }
-
template <class T>
using sizeof_bits = cutlass::sizeof_bits<
    typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
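For context on what the deleted `findTotalEltsLessThanTarget_v2` was doing: instead of a per-thread binary search, each lane of a warp counts its strided slice of elements below `target`, and the partial counts are summed with warp shuffles. A self-contained sketch of that pattern follows (illustrative names only; it uses a butterfly `__shfl_xor_sync` reduction rather than the `__shfl_down_sync` plus broadcast in the deleted code, and assumes all 32 lanes of the warp call it with the same arguments):

```cuda
#include <cuda_runtime.h>

// Warp-cooperative "count elements < target"; every lane returns the warp-wide total.
__device__ inline int warpCountLessThan(int const* data, int len, int target) {
    constexpr unsigned kFullMask = 0xffffffffu;
    int const lane = threadIdx.x & 31;

    int local = 0;
    for (int i = lane; i < len; i += 32) {  // each lane scans a strided slice
        local += (data[i] < target) ? 1 : 0;
    }
    // Butterfly reduction: afterwards every lane holds the sum of all 32 partials.
    for (int offset = 16; offset > 0; offset >>= 1) {
        local += __shfl_xor_sync(kFullMask, local, offset);
    }
    return local;
}
```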
@@ -1458,23 +1414,20 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
// source matrix, we simply take the modulus of the expanded index.

- constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
+ constexpr static int EXPAND_THREADS_PER_BLOCK = 256;

template <class InputActivationsType, class ExpandedActivationsType,
    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
    bool PRE_QUANT_AWQ>
__global__ void expandInputRowsKernel(
    InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
    float const* unpermuted_scales, float* permuted_scales,
-     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
+     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
    int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
    int64_t const* expert_first_token_offset,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
    int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
-     constexpr int hidden_size = 7168;
-     if (hidden_size != hidden_size_real_) { asm("trap;"); }
-
    static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
            !PRE_QUANT_AWQ,
        "AWQ and Block Scaling are mutually exclusive");
@@ -1550,8 +1503,8 @@ __global__ void expandInputRowsKernel(
        permuted_row * hidden_size / ELEM_PER_THREAD;

    int64_t const start_offset = threadIdx.x;
-     constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
-     constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+     int64_t const stride = EXPAND_THREADS_PER_BLOCK;
+     int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
    assert(hidden_size % ELEM_PER_THREAD == 0);
    assert(hidden_size % VecSize == 0);
@@ -1566,7 +1519,6 @@ __global__ void expandInputRowsKernel(
    float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
    int64_t num_tokens_before_expert = expert_first_token_offset[expert];

-     #pragma unroll
    for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
        auto in_vec = source_row_ptr[elem_index];
        if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1698,7 +1650,7 @@ void expandInputRowsKernelLauncher(

    static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-     int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens));
+     int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
    int64_t const threads = EXPAND_THREADS_PER_BLOCK;

    auto func = [&]() {
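A worked example of the grid-size heuristic restored above, under assumed values (the SM count and problem sizes here are hypothetical, not taken from the source):

```cpp
#include <algorithm>
#include <cstdint>

int64_t exampleGridSize() {
    int64_t const smCount = 100;             // hypothetical SM count
    int64_t const num_rows = 512, k = 8;     // hypothetical token count and top-k
    int64_t const num_padding_tokens = 0;
    // Cap the grid at 8 blocks per SM, otherwise one block per expanded (row, k) pair:
    // min(100 * 8, max(512 * 8, 0)) = min(800, 4096) = 800 blocks.
    return std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
}
```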
@@ -1796,20 +1748,11 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
// connection.
template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
- __global__
- __maxnreg__(64)
- void finalizeMoeRoutingKernel(
+ __global__ void finalizeMoeRoutingKernel(
    GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-     int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
+     int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
    int const num_experts_per_node, int const start_expert_id) {
-     if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
-         printf("finalizeMoeRoutingKernel see unsupported dtype\n");
-         asm("trap;");
-     } else {
-         constexpr int experts_per_token = 8;
-         if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
-
    int64_t const original_row = blockIdx.x;
    int64_t const num_rows = gridDim.x;
    auto const offset = original_row * orig_cols;
@@ -1841,67 +1784,43 @@ if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_sam
    for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
        ComputeElem thread_output;
        thread_output.fill(0);
-
-         int4 input_val_buf[experts_per_token];
-         uint32_t enable_input_buf = 0;
-
-         #pragma unroll
        for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
            int64_t const k_offset = original_row * experts_per_token + k_idx;
            int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
-
-             int64_t const expanded_original_row = original_row + k_idx * num_rows;
-             int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
-
            if (expert_id < 0 || expert_id >= num_experts_per_node) {
                continue;
            }

+             int64_t const expanded_original_row = original_row + k_idx * num_rows;
+             int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+
            int64_t expanded_rows = num_rows * experts_per_token;
            if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
                continue;
            }

+             float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
+
            auto const* expanded_permuted_rows_row_ptr =
                expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

-             // ComputeElem expert_result =
-             //     arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
-             static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4));
-             input_val_buf[k_idx] = *reinterpret_cast<const int4*>(expanded_permuted_rows_row_ptr + elem_index);
-             enable_input_buf |= 1 << k_idx;
-         }
-
-         #pragma unroll
-         for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
-             if (not (enable_input_buf & (1 << k_idx))) continue;
-
-             int64_t const k_offset = original_row * experts_per_token + k_idx;
-             float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
-
-             int4 input_val = input_val_buf[k_idx];
-             ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(*reinterpret_cast<const InputElem*>(&input_val));
-             // if (bias) {
-             //     auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
-             //     expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
-             // }
+             ComputeElem expert_result =
+                 arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
+             if (bias) {
+                 auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
+                 expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
+             }

            thread_output = thread_output + row_scale * expert_result;
        }

-         // OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
-         // reduced_row_ptr_v[elem_index] = output_elem;
-         // TODO alignment issue?
-         __align__(16) OutputElem output_elem_original = arrayConvert<ComputeElem, OutputElem>(thread_output);
-         int4 output_elem = *reinterpret_cast<int4*>(&output_elem_original);
-         static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4));
-         *reinterpret_cast<int4*>(reduced_row_ptr_v + elem_index) = output_elem;
+         OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
+         reduced_row_ptr_v[elem_index] = output_elem;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}
- }

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
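The restored loop body above is the plain k-way reduction the kernel comment describes: each output element is the scale-weighted sum over the token's selected experts, with an optional per-expert bias. A scalar, host-side sketch of that reduction for a single element (names are illustrative, not the kernel's, and all vectorization is omitted):

```cpp
#include <cstddef>
#include <vector>

// Reference for one (token, column) output of the finalize/unpermute step.
inline float finalizeOneElement(std::vector<std::vector<float>> const& expanded_rows,
                                std::vector<float> const& row_scales,      // one routing scale per selected expert
                                std::vector<int> const& permuted_row_of_k, // expanded row index per k
                                std::vector<int> const& expert_of_k,       // expert id per k
                                std::vector<std::vector<float>> const* bias,
                                int col) {
    float acc = 0.f;
    for (size_t k = 0; k < permuted_row_of_k.size(); ++k) {
        float v = expanded_rows[permuted_row_of_k[k]][col];
        if (bias) {
            v += (*bias)[expert_of_k[k]][col];  // optional per-expert bias
        }
        acc += row_scales[k] * v;  // k-way weighted sum
    }
    return acc;
}
```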
@@ -2211,13 +2130,10 @@ template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
    float const* fp8_quant, ScaleBiasType const* bias_ptr,
    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-     int num_experts_per_node, int64_t inter_size_real_,
+     int num_experts_per_node, int64_t inter_size,
    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
    ActivationParams activation_params) {
-     constexpr int inter_size = 2048;
-     if (inter_size != inter_size_real_) { asm("trap;"); }
-
#ifdef ENABLE_FP4
    constexpr bool IsNVFP4 =
        std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2302,18 +2218,16 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
    auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
    auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
    int64_t const start_offset = tid;
-     constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
+     int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
    assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-     constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+     int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
    assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
    int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;

    ActFn fn{};
    fn.alpha = gate_alpha;
    fn.beta = gate_beta;
    fn.limit = gate_limit;
-
-     #pragma unroll
    for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
        auto fc1_value =
            arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
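The loop just above (like the matching one in `expandInputRowsKernel`) is the standard block-strided column loop: thread `tid` handles vectorized elements `tid, tid + blockDim.x, ...` of one row, so the loop works for any `num_elems_in_col` now that the stride is no longer `constexpr`. A stripped-down sketch of the pattern (an illustrative kernel, not the library's):

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Each block owns one row; its threads stride across that row's columns.
__global__ void blockStridedRowKernel(float const* in, float* out, int64_t cols) {
    int64_t const row = blockIdx.x;
    for (int64_t c = threadIdx.x; c < cols; c += blockDim.x) {
        out[row * cols + c] = in[row * cols + c] * 2.0f;  // placeholder per-element op
    }
}
```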