@@ -45,78 +45,6 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
     torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
-    std::cout << "Function: run_fp4_block_scale_moe_runner" << std::endl;
-
-    auto print_tensor = [](std::string name, torch::Tensor const& t)
-    {
-        std::cout << name << ": shape=[";
-        for (auto s : t.sizes())
-        {
-            std::cout << s << ",";
-        }
-        std::cout << "], dtype=" << t.scalar_type() << std::endl;
-    };
-
-    auto print_opt_tensor = [&](std::string name, auto const& t)
-    {
-        if (t.has_value())
-        {
-            print_tensor(name, t.value());
-        }
-        else
-        {
-            std::cout << name << ": None" << std::endl;
-        }
-    };
-
-    auto print_val = [](std::string name, auto const& v) { std::cout << name << ": " << v << std::endl; };
-
-    auto print_opt_val = [&](std::string name, auto const& v)
-    {
-        if (v.has_value())
-        {
-            std::cout << name << ": " << v.value() << std::endl;
-        }
-        else
-        {
-            std::cout << name << ": None" << std::endl;
-        }
-    };
-
-    print_opt_tensor("routing_logits", routing_logits);
-    print_opt_tensor("routing_bias", routing_bias);
-    print_tensor("hidden_states", hidden_states);
-    print_opt_tensor("hidden_states_scale", hidden_states_scale);
-    print_tensor("gemm1_weights", gemm1_weights);
-    print_tensor("gemm1_weights_scale", gemm1_weights_scale);
-    print_opt_tensor("gemm1_bias", gemm1_bias);
-    print_opt_tensor("gemm1_alpha", gemm1_alpha);
-    print_opt_tensor("gemm1_beta", gemm1_beta);
-    print_opt_tensor("gemm1_clamp_limit", gemm1_clamp_limit);
-    print_tensor("gemm2_weights", gemm2_weights);
-    print_tensor("gemm2_weights_scale", gemm2_weights_scale);
-    print_opt_tensor("gemm2_bias", gemm2_bias);
-    print_tensor("output1_scales_scalar", output1_scales_scalar);
-    print_tensor("output1_scales_gate_scalar", output1_scales_gate_scalar);
-    print_tensor("output2_scales_scalar", output2_scales_scalar);
-
-    print_val("num_experts", num_experts);
-    print_val("top_k", top_k);
-    print_opt_val("n_group", n_group);
-    print_opt_val("topk_group", topk_group);
-    print_val("intermediate_size", intermediate_size);
-    print_val("local_expert_offset", local_expert_offset);
-    print_val("local_num_experts", local_num_experts);
-    print_opt_val("routed_scaling_factor", routed_scaling_factor);
-    print_val("tile_tokens_dim", tile_tokens_dim);
-    print_val("routing_method_type", routing_method_type);
-    print_val("do_finalize", do_finalize);
-    print_val("dtype", static_cast<int>(dtype));
-    print_val("moeConfigIndex", moeConfigIndex);
-    print_opt_tensor("topk_weights", topk_weights);
-    print_opt_tensor("topk_ids", topk_ids);
-    std::cout << "--------------------------------" << std::endl;
-
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
     TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64
@@ -526,7 +454,7 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
 public:
     explicit FP4BlockScaleMoeRunner()
         // Update this as new cubins come in
-        : mSupportedTileN{8, 16, 32, 64, 128, 256}
+        : mSupportedTileN{8, 16, 32, 64, 128}
     {
         for (int tileN : mSupportedTileN)
         {