 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>
 #include <ATen/ops/index_select.h>
+#include <iostream>

 namespace torch_ext
 {
@@ -44,6 +45,78 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
     torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
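+    // Debug: dump every input tensor's shape/dtype and each scalar argument to stdout before validation.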
+    std::cout << "Function: run_fp4_block_scale_moe_runner" << std::endl;
+
+    auto print_tensor = [](std::string name, torch::Tensor const& t)
+    {
+        std::cout << name << ": shape=[";
+        for (auto s : t.sizes())
+        {
+            std::cout << s << ",";
+        }
+        std::cout << "], dtype=" << t.scalar_type() << std::endl;
+    };
+
+    auto print_opt_tensor = [&](std::string name, auto const& t)
+    {
+        if (t.has_value())
+        {
+            print_tensor(name, t.value());
+        }
+        else
+        {
+            std::cout << name << ": None" << std::endl;
+        }
+    };
+
+    auto print_val = [](std::string name, auto const& v) { std::cout << name << ": " << v << std::endl; };
+
+    auto print_opt_val = [&](std::string name, auto const& v)
+    {
+        if (v.has_value())
+        {
+            std::cout << name << ": " << v.value() << std::endl;
+        }
+        else
+        {
+            std::cout << name << ": None" << std::endl;
+        }
+    };
+
+    print_opt_tensor("routing_logits", routing_logits);
+    print_opt_tensor("routing_bias", routing_bias);
+    print_tensor("hidden_states", hidden_states);
+    print_opt_tensor("hidden_states_scale", hidden_states_scale);
+    print_tensor("gemm1_weights", gemm1_weights);
+    print_tensor("gemm1_weights_scale", gemm1_weights_scale);
+    print_opt_tensor("gemm1_bias", gemm1_bias);
+    print_opt_tensor("gemm1_alpha", gemm1_alpha);
+    print_opt_tensor("gemm1_beta", gemm1_beta);
+    print_opt_tensor("gemm1_clamp_limit", gemm1_clamp_limit);
+    print_tensor("gemm2_weights", gemm2_weights);
+    print_tensor("gemm2_weights_scale", gemm2_weights_scale);
+    print_opt_tensor("gemm2_bias", gemm2_bias);
+    print_tensor("output1_scales_scalar", output1_scales_scalar);
+    print_tensor("output1_scales_gate_scalar", output1_scales_gate_scalar);
+    print_tensor("output2_scales_scalar", output2_scales_scalar);
+
+    print_val("num_experts", num_experts);
+    print_val("top_k", top_k);
+    print_opt_val("n_group", n_group);
+    print_opt_val("topk_group", topk_group);
+    print_val("intermediate_size", intermediate_size);
+    print_val("local_expert_offset", local_expert_offset);
+    print_val("local_num_experts", local_num_experts);
+    print_opt_val("routed_scaling_factor", routed_scaling_factor);
+    print_val("tile_tokens_dim", tile_tokens_dim);
+    print_val("routing_method_type", routing_method_type);
+    print_val("do_finalize", do_finalize);
+    print_val("dtype", static_cast<int>(dtype));
+    print_val("moeConfigIndex", moeConfigIndex);
+    print_opt_tensor("topk_weights", topk_weights);
+    print_opt_tensor("topk_ids", topk_ids);
+    std::cout << "--------------------------------" << std::endl;
+
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
     TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64