Commit 1028207

refactor: unify attention parameter structures and update function signatures for clarity.
1 parent b255e9a
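This commit replaces the hand-written default constructors in param.h with C++11 in-class member initializers and folds the separate PrefillAttentionParams and DecodeAttentionParams structs into one AttentionParams. A minimal sketch of the before/after pattern, with hypothetical struct and field names; note that a std::optional member already default-constructs to std::nullopt, which is why those explicit initializers can simply be dropped:

  // Before: defaults live in a constructor init list.
  struct OldParams {
    std::optional<torch::Tensor> bias;
    bool flag;
    OldParams() : bias(std::nullopt), flag(false) {}
  };

  // After: defaults live on the members themselves; no constructor is
  // needed, and the struct stays an aggregate.
  struct NewParams {
    std::optional<torch::Tensor> bias;  // default-constructs to std::nullopt
    bool flag = false;                  // in-class member initializer
  };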

4 files changed: +186 -271 lines changed

xllm/core/kernels/param.h

Lines changed: 39 additions & 119 deletions
@@ -34,10 +34,8 @@ struct RotaryParams {
   torch::Tensor cu_query_lens;
   bool interleaved;
   bool discrete;
-  bool dynamic_ntk;
+  bool dynamic_ntk = false;
   int max_query_len;
-
-  RotaryParams() : position_ids(std::nullopt), dynamic_ntk(false) {}
 };
 
 // Activation parameters
@@ -48,14 +46,8 @@ struct ActivationParams {
   std::optional<torch::Tensor> cusum_token_count;
   std::string act_mode;
   bool is_gated;
-  int start_expert_id;
-  int expert_size;
-
-  ActivationParams()
-      : bias(std::nullopt),
-        cusum_token_count(std::nullopt),
-        start_expert_id(0),
-        expert_size(0) {}
+  int start_expert_id = 0;
+  int expert_size = 0;
 };
 
 // Reshape paged cache parameters
@@ -65,82 +57,46 @@ struct ReshapePagedCacheParams {
   torch::Tensor k_cache;
   torch::Tensor v_cache;
   torch::Tensor slot_mapping;
-  bool direction;
-
-  ReshapePagedCacheParams() : direction(false) {}
+  bool direction = false;
 };
 
-// Prefill attention parameters
-struct PrefillAttentionParams {
+// Attention parameters
+struct AttentionParams {
+  // common parameters
   torch::Tensor query;
-  torch::Tensor key;
-  torch::Tensor value;
   torch::Tensor output;
   torch::Tensor output_lse;
-  torch::Tensor query_start_loc;
-  torch::Tensor seq_start_loc;
   std::optional<torch::Tensor> alibi_slope;
-  std::optional<torch::Tensor> attn_bias;
   std::optional<torch::Tensor> q_quant_scale;
-  std::optional<torch::Tensor> k_quant_scale;
-  std::optional<torch::Tensor> v_quant_scale;
   std::optional<torch::Tensor> out_quant_scale;
-  std::optional<torch::Tensor> block_tables;
-  int max_query_len;
+  std::optional<torch::Tensor> block_table;
+  std::string compute_dtype;
   int max_seq_len;
-  float scale;
-  bool is_causal;
   int window_size_left;
-  int window_size_right;
-  std::string compute_dtype;
-  bool return_lse;
-
-  FlashAttentionParams()
-      : alibi_slope(std::nullopt),
-        attn_bias(std::nullopt),
-        q_quant_scale(std::nullopt),
-        k_quant_scale(std::nullopt),
-        v_quant_scale(std::nullopt),
-        out_quant_scale(std::nullopt),
-        block_tables(std::nullopt),
-        is_causal(true),
-        window_size_right(-1),
-        return_lse(false) {}
-};
+  int window_size_right = -1;
+  float scale;
+  bool return_lse = false;
 
-// Decode attention parameters
-struct DecodeAttentionParams {
-  torch::Tensor query;
-  torch::Tensor k_cache;
-  torch::Tensor output;
-  torch::Tensor block_table;
-  torch::Tensor seq_lens;
-  torch::Tensor v_cache;
-  torch::Tensor output_lse;
-  std::optional<torch::Tensor> q_quant_scale;
+  // prefill parameters
+  std::optional<torch::Tensor> key;
+  std::optional<torch::Tensor> value;
+  std::optional<torch::Tensor> query_start_loc;
+  std::optional<torch::Tensor> seq_start_loc;
+  std::optional<torch::Tensor> attn_bias;
+  std::optional<torch::Tensor> k_quant_scale;
+  std::optional<torch::Tensor> v_quant_scale;
+  int max_query_len;
+  bool is_causal = true;
+
+  // decode parameters
+  std::optional<torch::Tensor> k_cache;
+  std::optional<torch::Tensor> v_cache;
+  std::optional<torch::Tensor> block_table;
+  std::optional<torch::Tensor> seq_lens;
   std::optional<torch::Tensor> k_cache_quant_scale;
   std::optional<torch::Tensor> v_cache_quant_scale;
-  std::optional<torch::Tensor> out_quant_scale;
-  std::optional<torch::Tensor> alibi_slope;
   std::optional<torch::Tensor> mask;
-  std::string compute_dtype;
-  int max_seq_len;
-  int window_size_left;
-  int window_size_right;
-  float scale;
-  bool return_lse;
-  int kv_cache_quant_bit_size;
-
-  DecodeAttentionParams()
-      : q_quant_scale(std::nullopt),
-        k_cache_quant_scale(std::nullopt),
-        v_cache_quant_scale(std::nullopt),
-        out_quant_scale(std::nullopt),
-        alibi_slope(std::nullopt),
-        mask(std::nullopt),
-        window_size_right(-1),
-        return_lse(false),
-        kv_cache_quant_bit_size(-1) {}
+  int kv_cache_quant_bit_size = -1;
 };
 
 // Fused layer norm parameters
@@ -157,21 +113,9 @@ struct FusedLayerNormParams {
   std::optional<torch::Tensor> normed_out;
   std::string mode;
   double eps;
-  bool store_output_before_norm;
-  bool store_output_after_norm;
-  bool dynamic_quant;
-
-  FusedLayerNormParams()
-      : residual(std::nullopt),
-        beta(std::nullopt),
-        bias(std::nullopt),
-        quant_scale(std::nullopt),
-        residual_out(std::nullopt),
-        smooth_quant_scale(std::nullopt),
-        normed_out(std::nullopt),
-        store_output_before_norm(false),
-        store_output_after_norm(false),
-        dynamic_quant(false) {}
+  bool store_output_before_norm = false;
+  bool store_output_after_norm = false;
+  bool dynamic_quant = false;
 };
 
 // Matmul parameters
@@ -182,8 +126,6 @@ struct MatmulParams {
   std::optional<torch::Tensor> c;
   double alpha;
   double beta;
-
-  MatmulParams() : bias(std::nullopt), c(std::nullopt) {}
 };
 
 // Fused MoE parameters
@@ -203,39 +145,17 @@ struct FusedMoEParams {
   bool renormalize;
   bool gated;
   std::string act_mode;
-  std::string scoring_func;
-  int start_expert_id;
-  int block_n;
-  bool avg_moe;
+  std::string scoring_func = "softmax";
+  int start_expert_id = 0;
+  int block_n = 0;
+  bool avg_moe = false;
   std::optional<torch::Tensor> class_reduce_weight;
   std::optional<torch::Tensor> class_expert_id;
   std::optional<std::vector<bool>> w1_quant_flag;
   std::optional<std::vector<bool>> w2_quant_flag;
-  int world_size;
-  int shared_expert_num;
-  std::string parallel_mode;
-
-  FusedMoEParams()
-      : bias1(std::nullopt),
-        bias2(std::nullopt),
-        residual(std::nullopt),
-        input_smooth(std::nullopt),
-        act_smooth(std::nullopt),
-        w1_scale(std::nullopt),
-        w2_scale(std::nullopt),
-        w1_quant_flag(std::nullopt),
-        w2_quant_flag(std::nullopt),
-        scoring_func("softmax"),
-        start_expert_id(0),
-        block_n(0),
-        avg_moe(false),
-        class_reduce_weight(std::nullopt),
-        class_expert_id(std::nullopt),
-        w1_quant_flag(std::nullopt),
-        w2_quant_flag(std::nullopt),
-        world_size(0),
-        shared_expert_num(0),
-        parallel_mode("ep") {}
+  int world_size = 0;
+  int shared_expert_num = 0;
+  std::string parallel_mode = "ep";
 };
 } // namespace kernel
 } // namespace xllm
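Below is a minimal usage sketch, not part of the commit, of how a caller might populate the unified AttentionParams for the two attention paths. The field names come from the diff above; everything else (the "bf16" dtype tag, the helper names, the head_dim-based scale) is an assumption for illustration. Leaving the other path's optionals at std::nullopt is what lets a kernel tell prefill and decode apart.

  #include <cmath>
  #include <utility>

  #include <torch/torch.h>

  #include "xllm/core/kernels/param.h"

  using xllm::kernel::AttentionParams;

  // Prefill: flat key/value tensors and per-sequence offsets are set; the
  // decode-only optionals (k_cache, v_cache, block_table, seq_lens) stay
  // std::nullopt.
  AttentionParams make_prefill_params(torch::Tensor query,
                                      torch::Tensor key,
                                      torch::Tensor value,
                                      torch::Tensor output,
                                      torch::Tensor query_start_loc,
                                      torch::Tensor seq_start_loc,
                                      int max_query_len,
                                      int max_seq_len,
                                      int head_dim) {
    AttentionParams p;
    p.query = std::move(query);
    p.output = std::move(output);
    p.compute_dtype = "bf16";  // assumed dtype tag
    p.max_seq_len = max_seq_len;
    p.window_size_left = -1;   // no sliding window
    p.scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    p.key = std::move(key);
    p.value = std::move(value);
    p.query_start_loc = std::move(query_start_loc);
    p.seq_start_loc = std::move(seq_start_loc);
    p.max_query_len = max_query_len;
    // is_causal (true), window_size_right (-1) and return_lse (false)
    // already hold their new in-class defaults.
    return p;
  }

  // Decode: the paged KV-cache tensors are set instead; the prefill-only
  // optionals remain std::nullopt.
  AttentionParams make_decode_params(torch::Tensor query,
                                     torch::Tensor output,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     torch::Tensor block_table,
                                     torch::Tensor seq_lens,
                                     int max_seq_len,
                                     int head_dim) {
    AttentionParams p;
    p.query = std::move(query);
    p.output = std::move(output);
    p.compute_dtype = "bf16";
    p.max_seq_len = max_seq_len;
    p.window_size_left = -1;
    p.scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    p.k_cache = std::move(k_cache);
    p.v_cache = std::move(v_cache);
    p.block_table = std::move(block_table);
    p.seq_lens = std::move(seq_lens);
    return p;
  }

Note that the scalar fields without an in-class initializer (scale, max_seq_len, max_query_len, window_size_left) are still left uninitialized after this refactor, so callers must set them explicitly on both paths.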
