@@ -34,10 +34,8 @@ struct RotaryParams {
3434 torch::Tensor cu_query_lens;
3535 bool interleaved;
3636 bool discrete;
37- bool dynamic_ntk;
37+ bool dynamic_ntk = false;
3838 int max_query_len;
39-
40- RotaryParams () : position_ids(std::nullopt ), dynamic_ntk(false ) {}
4139};
4240
4341// Activation parameters
@@ -48,14 +46,8 @@ struct ActivationParams {
4846 std::optional<torch::Tensor> cusum_token_count;
4947 std::string act_mode;
5048 bool is_gated;
51- int start_expert_id;
52- int expert_size;
53-
54- ActivationParams ()
55- : bias(std::nullopt ),
56- cusum_token_count (std::nullopt ),
57- start_expert_id(0 ),
58- expert_size(0 ) {}
49+ int start_expert_id = 0;
50+ int expert_size = 0;
5951};
6052
6153// Reshape paged cache parameters
@@ -65,82 +57,46 @@ struct ReshapePagedCacheParams {
6557 torch::Tensor k_cache;
6658 torch::Tensor v_cache;
6759 torch::Tensor slot_mapping;
68- bool direction;
69-
70- ReshapePagedCacheParams () : direction(false ) {}
60+ bool direction = false;
7161};
7262
73- // Prefill attention parameters
74- struct PrefillAttentionParams {
63+ // Attention parameters
64+ struct AttentionParams {
65+ // common parameters
7566 torch::Tensor query;
76- torch::Tensor key;
77- torch::Tensor value;
7867 torch::Tensor output;
7968 torch::Tensor output_lse;
80- torch::Tensor query_start_loc;
81- torch::Tensor seq_start_loc;
8269 std::optional<torch::Tensor> alibi_slope;
83- std::optional<torch::Tensor> attn_bias;
8470 std::optional<torch::Tensor> q_quant_scale;
85- std::optional<torch::Tensor> k_quant_scale;
86- std::optional<torch::Tensor> v_quant_scale;
8771 std::optional<torch::Tensor> out_quant_scale;
88- std::optional<torch::Tensor> block_tables ;
89- int max_query_len ;
72+ std::optional<torch::Tensor> block_table;
73+ std::string compute_dtype;
9074 int max_seq_len;
91- float scale;
92- bool is_causal;
9375 int window_size_left;
94- int window_size_right;
95- std::string compute_dtype;
96- bool return_lse;
97-
98- FlashAttentionParams ()
99- : alibi_slope(std::nullopt ),
100- attn_bias (std::nullopt ),
101- q_quant_scale(std::nullopt ),
102- k_quant_scale(std::nullopt ),
103- v_quant_scale(std::nullopt ),
104- out_quant_scale(std::nullopt ),
105- block_tables(std::nullopt ),
106- is_causal(true ),
107- window_size_right(-1 ),
108- return_lse(false ) {}
109- };
76+ int window_size_right = -1;
77+ float scale;
78+ bool return_lse = false;
11079
111- // Decode attention parameters
112- struct DecodeAttentionParams {
113- torch::Tensor query;
114- torch::Tensor k_cache;
115- torch::Tensor output;
116- torch::Tensor block_table;
117- torch::Tensor seq_lens;
118- torch::Tensor v_cache;
119- torch::Tensor output_lse;
120- std::optional<torch::Tensor> q_quant_scale;
80+ // prefill parameters
81+ std::optional<torch::Tensor> key;
82+ std::optional<torch::Tensor> value;
83+ std::optional<torch::Tensor> query_start_loc;
84+ std::optional<torch::Tensor> seq_start_loc;
85+ std::optional<torch::Tensor> attn_bias;
86+ std::optional<torch::Tensor> k_quant_scale;
87+ std::optional<torch::Tensor> v_quant_scale;
88+ int max_query_len;
89+ bool is_causal = true;
90+
91+ // decode parameters
92+ std::optional<torch::Tensor> k_cache;
93+ std::optional<torch::Tensor> v_cache;
94+ std::optional<torch::Tensor> block_table;  // NOTE(review): redefinition — `block_table` is already declared in the common-parameters section above (new line 72); the merged AttentionParams cannot declare it twice, one of the two must go
95+ std::optional<torch::Tensor> seq_lens;
12196 std::optional<torch::Tensor> k_cache_quant_scale;
12297 std::optional<torch::Tensor> v_cache_quant_scale;
123- std::optional<torch::Tensor> out_quant_scale;
124- std::optional<torch::Tensor> alibi_slope;
12598 std::optional<torch::Tensor> mask;
126- std::string compute_dtype;
127- int max_seq_len;
128- int window_size_left;
129- int window_size_right;
130- float scale;
131- bool return_lse;
132- int kv_cache_quant_bit_size;
133-
134- DecodeAttentionParams ()
135- : q_quant_scale(std::nullopt ),
136- k_cache_quant_scale (std::nullopt ),
137- v_cache_quant_scale(std::nullopt ),
138- out_quant_scale(std::nullopt ),
139- alibi_slope(std::nullopt ),
140- mask(std::nullopt ),
141- window_size_right(-1 ),
142- return_lse(false ),
143- kv_cache_quant_bit_size(-1 ) {}
99+ int kv_cache_quant_bit_size = -1;
144100};
145101
146102// Fused layer norm parameters
@@ -157,21 +113,9 @@ struct FusedLayerNormParams {
157113 std::optional<torch::Tensor> normed_out;
158114 std::string mode;
159115 double eps;
160- bool store_output_before_norm;
161- bool store_output_after_norm;
162- bool dynamic_quant;
163-
164- FusedLayerNormParams ()
165- : residual(std::nullopt ),
166- beta (std::nullopt ),
167- bias(std::nullopt ),
168- quant_scale(std::nullopt ),
169- residual_out(std::nullopt ),
170- smooth_quant_scale(std::nullopt ),
171- normed_out(std::nullopt ),
172- store_output_before_norm(false ),
173- store_output_after_norm(false ),
174- dynamic_quant(false ) {}
116+ bool store_output_before_norm = false;
117+ bool store_output_after_norm = false;
118+ bool dynamic_quant = false;
175119};
176120
177121// Matmul parameters
@@ -182,8 +126,6 @@ struct MatmulParams {
182126 std::optional<torch::Tensor> c;
183127 double alpha;
184128 double beta;
185-
186- MatmulParams () : bias(std::nullopt ), c(std::nullopt ) {}
187129};
188130
189131// Fused MoE parameters
@@ -203,39 +145,17 @@ struct FusedMoEParams {
203145 bool renormalize;
204146 bool gated;
205147 std::string act_mode;
206- std::string scoring_func;
207- int start_expert_id;
208- int block_n;
209- bool avg_moe;
148+ std::string scoring_func = "softmax";
149+ int start_expert_id = 0;
150+ int block_n = 0;
151+ bool avg_moe = false;
210152 std::optional<torch::Tensor> class_reduce_weight;
211153 std::optional<torch::Tensor> class_expert_id;
212154 std::optional<std::vector<bool >> w1_quant_flag;
213155 std::optional<std::vector<bool >> w2_quant_flag;
214- int world_size;
215- int shared_expert_num;
216- std::string parallel_mode;
217-
218- FusedMoEParams ()
219- : bias1(std::nullopt ),
220- bias2 (std::nullopt ),
221- residual(std::nullopt ),
222- input_smooth(std::nullopt ),
223- act_smooth(std::nullopt ),
224- w1_scale(std::nullopt ),
225- w2_scale(std::nullopt ),
226- w1_quant_flag(std::nullopt ),
227- w2_quant_flag(std::nullopt ),
228- scoring_func(" softmax" ),
229- start_expert_id(0 ),
230- block_n(0 ),
231- avg_moe(false ),
232- class_reduce_weight(std::nullopt ),
233- class_expert_id(std::nullopt ),
234- w1_quant_flag(std::nullopt ),
235- w2_quant_flag(std::nullopt ),
236- world_size(0 ),
237- shared_expert_num(0 ),
238- parallel_mode(" ep" ) {}
156+ int world_size = 0;
157+ int shared_expert_num = 0;
158+ std::string parallel_mode = "ep";
239159};
240160} // namespace kernel
241161} // namespace xllm
0 commit comments