@@ -142,16 +142,12 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
     auto head_dim = model_args.head_dim();
     auto query_dim = heads_ * head_dim;
     auto out_dim = query_dim;
-    to_q_ = register_module("to_q",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));
-    to_k_ = register_module("to_k",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));
-    to_v_ = register_module("to_v",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));
 
-    to_q_->to(options_);
-    to_k_->to(options_);
-    to_v_->to(options_);
+    fused_qkv_weight_ = register_parameter(
+        "fused_qkv_weight", torch::empty({3 * query_dim, out_dim}, options_));
+
+    fused_qkv_bias_ = register_parameter("fused_qkv_bias",
+                                         torch::empty({3 * out_dim}, options_));
 
     norm_q_ = register_module("norm_q",
                               DiTRMSNorm(head_dim,
@@ -170,19 +166,15 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
   torch::Tensor forward(const torch::Tensor& hidden_states,
                         const torch::Tensor& image_rotary_emb) {
     int64_t batch_size, channel, height, width;
+    batch_size = hidden_states.size(0);
 
-    // Reshape 4D input to [B, seq_len, C]
-    torch::Tensor hidden_states_ =
-        hidden_states;  // Use copy to avoid modifying input
-    batch_size = hidden_states_.size(0);
-
-    // Self-attention: use hidden_states as context
-    torch::Tensor context = hidden_states_;
+    auto qkv = torch::nn::functional::linear(
+        hidden_states, fused_qkv_weight_, fused_qkv_bias_);
+    auto chunks = qkv.chunk(3, -1);
 
-    // Compute QKV projections
-    torch::Tensor query = to_q_->forward(hidden_states_);
-    torch::Tensor key = to_k_->forward(context);
-    torch::Tensor value = to_v_->forward(context);
+    torch::Tensor query = chunks[0];
+    torch::Tensor key = chunks[1];
+    torch::Tensor value = chunks[2];
 
     // Reshape for multi-head attention
     int64_t inner_dim = key.size(-1);
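
The fused projection introduced above is numerically equivalent to the three separate to_q/to_k/to_v projections it replaces: concatenating the weights along the output dimension, running one linear, and splitting the result with chunk(3, -1) reproduces the per-projection outputs. A minimal standalone sketch of that equivalence (toy sizes and a bare main(), assuming only LibTorch; this is illustrative, not the actual module code):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t dim = 8;
  auto x = torch::randn({2, 5, dim});  // [batch, seq_len, dim]
  auto wq = torch::randn({dim, dim}), bq = torch::randn({dim});
  auto wk = torch::randn({dim, dim}), bk = torch::randn({dim});
  auto wv = torch::randn({dim, dim}), bv = torch::randn({dim});

  // Separate projections, as in the old code path.
  auto q0 = torch::nn::functional::linear(x, wq, bq);
  auto k0 = torch::nn::functional::linear(x, wk, bk);
  auto v0 = torch::nn::functional::linear(x, wv, bv);

  // Fused projection: concatenate along dim 0 of the [out_features, in_features]
  // weights, run one GEMM, then split the last dimension into three chunks.
  auto qkv = torch::nn::functional::linear(
      x, torch::cat({wq, wk, wv}, 0), torch::cat({bq, bk, bv}, 0));
  auto chunks = qkv.chunk(3, -1);

  std::cout << torch::allclose(q0, chunks[0]) << " "
            << torch::allclose(k0, chunks[1]) << " "
            << torch::allclose(v0, chunks[2]) << std::endl;  // expect: 1 1 1
  return 0;
}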
@@ -210,26 +202,53 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
     norm_q_->load_state_dict(state_dict.get_dict_with_prefix("norm_q."));
     // norm_k
     norm_k_->load_state_dict(state_dict.get_dict_with_prefix("norm_k."));
-    // to_q
-    to_q_->load_state_dict(state_dict.get_dict_with_prefix("to_q."));
-    // to_k
-    to_k_->load_state_dict(state_dict.get_dict_with_prefix("to_k."));
-    // to_v
-    to_v_->load_state_dict(state_dict.get_dict_with_prefix("to_v."));
+
+    auto to_q_weight = state_dict.get_tensor("to_q.weight");
+    auto to_q_bias = state_dict.get_tensor("to_q.bias");
+    auto to_k_weight = state_dict.get_tensor("to_k.weight");
+    auto to_k_bias = state_dict.get_tensor("to_k.bias");
+    auto to_v_weight = state_dict.get_tensor("to_v.weight");
+    auto to_v_bias = state_dict.get_tensor("to_v.bias");
+
+    if (to_q_weight.defined() && to_k_weight.defined() &&
+        to_v_weight.defined()) {
+      auto fused_qkv_weight =
+          torch::cat({to_q_weight, to_k_weight, to_v_weight}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_weight_.sizes(), fused_qkv_weight.sizes())
+          << "fused_qkv_weight_ size mismatch: expected "
+          << fused_qkv_weight_.sizes() << " but got "
+          << fused_qkv_weight.sizes();
+      fused_qkv_weight_.data().copy_(fused_qkv_weight.to(
+          fused_qkv_weight_.device(), fused_qkv_weight_.dtype()));
+      is_qkv_weight_loaded_ = true;
+    }
+
+    if (to_q_bias.defined() && to_k_bias.defined() && to_v_bias.defined()) {
+      auto fused_qkv_bias =
+          torch::cat({to_q_bias, to_k_bias, to_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_bias_.sizes(), fused_qkv_bias.sizes())
+          << "fused_qkv_bias_ size mismatch: expected "
+          << fused_qkv_bias_.sizes() << " but got " << fused_qkv_bias.sizes();
+      fused_qkv_bias_.data().copy_(
+          fused_qkv_bias.to(fused_qkv_bias_.device(), fused_qkv_bias_.dtype()));
+      is_qkv_bias_loaded_ = true;
+    }
   }
 
   void verify_loaded_weights(const std::string& prefix) const {
+    CHECK(is_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "qkv_proj.weight";
+    CHECK(is_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "qkv_proj.bias";
     norm_q_->verify_loaded_weights(prefix + "norm_q.");
     norm_k_->verify_loaded_weights(prefix + "norm_k.");
-    to_q_->verify_loaded_weights(prefix + "to_q.");
-    to_k_->verify_loaded_weights(prefix + "to_k.");
-    to_v_->verify_loaded_weights(prefix + "to_v.");
   }
 
  private:
-  DiTLinear to_q_{nullptr};
-  DiTLinear to_k_{nullptr};
-  DiTLinear to_v_{nullptr};
+  bool is_qkv_weight_loaded_{false};
+  bool is_qkv_bias_loaded_{false};
+  torch::Tensor fused_qkv_weight_{};
+  torch::Tensor fused_qkv_bias_{};
   int64_t heads_;
   DiTRMSNorm norm_q_{nullptr};
   DiTRMSNorm norm_k_{nullptr};
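
The checkpoint still stores the projections under the original to_q./to_k./to_v. keys, so the fusion happens at load time: the three weight matrices, each laid out as [out_features, in_features], are concatenated along dim 0 into the pre-registered {3 * query_dim, out_dim} parameter, and the biases likewise into {3 * out_dim}. A small shape sketch of that layout (toy sizes and plain tensors standing in for the real StateDict; illustrative only):

#include <torch/torch.h>
#include <cassert>

int main() {
  const int64_t out_dim = 16, in_dim = 16;  // query_dim == out_dim in Flux
  // Stand-ins for state_dict.get_tensor("to_q.weight") and friends.
  auto to_q_w = torch::randn({out_dim, in_dim});
  auto to_k_w = torch::randn({out_dim, in_dim});
  auto to_v_w = torch::randn({out_dim, in_dim});

  // dim 0 is out_features, so the concatenation keeps the [3*out, in] layout
  // that fused_qkv_weight_ was registered with.
  auto fused_w = torch::cat({to_q_w, to_k_w, to_v_w}, 0).contiguous();
  assert(fused_w.size(0) == 3 * out_dim && fused_w.size(1) == in_dim);

  // Row block [out_dim, 2*out_dim) is exactly the original to_k weight.
  assert(torch::equal(fused_w.slice(0, out_dim, 2 * out_dim), to_k_w));
  return 0;
}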
@@ -248,29 +267,24 @@ class FluxAttentionImpl : public torch::nn::Module {
     auto out_dim = query_dim;
     auto added_kv_proj_dim = query_dim;
 
-    to_q_ = register_module("to_q", DiTLinear(query_dim, out_dim, true));
-    to_k_ = register_module("to_k", DiTLinear(query_dim, out_dim, true));
-    to_v_ = register_module("to_v", DiTLinear(query_dim, out_dim, true));
-    add_q_proj_ = register_module("add_q_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
-    add_k_proj_ = register_module("add_k_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
-    add_v_proj_ = register_module("add_v_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
     to_out_ = register_module("to_out", DiTLinear(out_dim, query_dim, true));
 
     to_add_out_ = register_module("to_add_out",
                                   DiTLinear(out_dim, added_kv_proj_dim, true));
 
-    to_q_->to(options_);
-    to_k_->to(options_);
-    to_v_->to(options_);
-    add_q_proj_->to(options_);
-    add_k_proj_->to(options_);
-    add_v_proj_->to(options_);
+    fused_qkv_weight_ = register_parameter(
+        "fused_qkv_weight", torch::empty({3 * query_dim, out_dim}, options_));
+
+    fused_qkv_bias_ = register_parameter("fused_qkv_bias",
+                                         torch::empty({3 * out_dim}, options_));
+
+    fused_add_qkv_weight_ = register_parameter(
+        "fused_add_qkv_weight",
+        torch::empty({3 * added_kv_proj_dim, out_dim}, options_));
+
+    fused_add_qkv_bias_ = register_parameter(
+        "fused_add_qkv_bias", torch::empty({3 * out_dim}, options_));
+
     to_out_->to(options_);
     to_add_out_->to(options_);
 
@@ -330,9 +344,15 @@ class FluxAttentionImpl : public torch::nn::Module {
             .transpose(1, 2);
     }
     int64_t batch_size = encoder_hidden_states_reshaped.size(0);
-    torch::Tensor query = to_q_->forward(hidden_states_reshaped);
-    torch::Tensor key = to_k_->forward(hidden_states_reshaped);
-    torch::Tensor value = to_v_->forward(hidden_states_reshaped);
+
+    auto qkv = torch::nn::functional::linear(
+        hidden_states_reshaped, fused_qkv_weight_, fused_qkv_bias_);
+
+    auto chunks = qkv.chunk(3, -1);
+    torch::Tensor query = chunks[0];
+    torch::Tensor key = chunks[1];
+    torch::Tensor value = chunks[2];
+
     int64_t inner_dim = key.size(-1);
     int64_t attn_heads = heads_;
 
@@ -342,13 +362,17 @@ class FluxAttentionImpl : public torch::nn::Module {
     value = value.view({batch_size, -1, attn_heads, head_dim}).transpose(1, 2);
     if (norm_q_) query = norm_q_->forward(query);
     if (norm_k_) key = norm_k_->forward(key);
-    // encoder hidden states
-    torch::Tensor encoder_hidden_states_query_proj =
-        add_q_proj_->forward(encoder_hidden_states_reshaped);
-    torch::Tensor encoder_hidden_states_key_proj =
-        add_k_proj_->forward(encoder_hidden_states_reshaped);
-    torch::Tensor encoder_hidden_states_value_proj =
-        add_v_proj_->forward(encoder_hidden_states_reshaped);
+
+    auto encoder_qkv =
+        torch::nn::functional::linear(encoder_hidden_states_reshaped,
+                                      fused_add_qkv_weight_,
+                                      fused_add_qkv_bias_);
+
+    auto encoder_chunks = encoder_qkv.chunk(3, -1);
+    torch::Tensor encoder_hidden_states_query_proj = encoder_chunks[0];
+    torch::Tensor encoder_hidden_states_key_proj = encoder_chunks[1];
+    torch::Tensor encoder_hidden_states_value_proj = encoder_chunks[2];
+
     encoder_hidden_states_query_proj =
         encoder_hidden_states_query_proj
             .view({batch_size, -1, attn_heads, head_dim})
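
Both streams then go through the same multi-head reshape shown in the surrounding context lines: each chunk of the fused output is viewed as [B, seq, heads, head_dim] and transposed to [B, heads, seq, head_dim]. A short shape-bookkeeping sketch (illustrative sizes, not Flux's real dimensions):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t batch = 2, seq = 6, heads = 4, head_dim = 8;
  const int64_t dim = heads * head_dim;

  // Output of the fused linear: query, key, value packed along the last dim.
  auto qkv = torch::randn({batch, seq, 3 * dim});
  auto chunks = qkv.chunk(3, -1);  // order q, k, v, matching the cat order at load time

  auto q = chunks[0].view({batch, -1, heads, head_dim}).transpose(1, 2);
  std::cout << q.sizes() << std::endl;  // [2, 4, 6, 8] = [B, heads, seq, head_dim]
  return 0;
}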
@@ -396,12 +420,6 @@ class FluxAttentionImpl : public torch::nn::Module {
   }
 
   void load_state_dict(const StateDict& state_dict) {
-    // to_q
-    to_q_->load_state_dict(state_dict.get_dict_with_prefix("to_q."));
-    // to_k
-    to_k_->load_state_dict(state_dict.get_dict_with_prefix("to_k."));
-    // to_v
-    to_v_->load_state_dict(state_dict.get_dict_with_prefix("to_v."));
     // to_out
     to_out_->load_state_dict(state_dict.get_dict_with_prefix("to_out.0."));
     // to_add_out
@@ -417,39 +435,98 @@ class FluxAttentionImpl : public torch::nn::Module {
     // norm_added_k
     norm_added_k_->load_state_dict(
         state_dict.get_dict_with_prefix("norm_added_k."));
-    // add_q_proj
-    add_q_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_q_proj."));
-    // add_k_proj
-    add_k_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_k_proj."));
-    // add_v_proj
-    add_v_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_v_proj."));
+
+    auto to_q_weight = state_dict.get_tensor("to_q.weight");
+    auto to_q_bias = state_dict.get_tensor("to_q.bias");
+    auto to_k_weight = state_dict.get_tensor("to_k.weight");
+    auto to_k_bias = state_dict.get_tensor("to_k.bias");
+    auto to_v_weight = state_dict.get_tensor("to_v.weight");
+    auto to_v_bias = state_dict.get_tensor("to_v.bias");
+
+    if (to_q_weight.defined() && to_k_weight.defined() &&
+        to_v_weight.defined()) {
+      auto fused_qkv_weight =
+          torch::cat({to_q_weight, to_k_weight, to_v_weight}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_weight_.sizes(), fused_qkv_weight.sizes())
+          << "fused_qkv_weight_ size mismatch: expected "
+          << fused_qkv_weight_.sizes() << " but got "
+          << fused_qkv_weight.sizes();
+      fused_qkv_weight_.data().copy_(fused_qkv_weight.to(
+          fused_qkv_weight_.device(), fused_qkv_weight_.dtype()));
+      is_qkv_weight_loaded_ = true;
+    }
+
+    if (to_q_bias.defined() && to_k_bias.defined() && to_v_bias.defined()) {
+      auto fused_qkv_bias =
+          torch::cat({to_q_bias, to_k_bias, to_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_bias_.sizes(), fused_qkv_bias.sizes())
+          << "fused_qkv_bias_ size mismatch: expected "
+          << fused_qkv_bias_.sizes() << " but got " << fused_qkv_bias.sizes();
+      fused_qkv_bias_.data().copy_(
+          fused_qkv_bias.to(fused_qkv_bias_.device(), fused_qkv_bias_.dtype()));
+      is_qkv_bias_loaded_ = true;
+    }
+
+    auto add_q_weight = state_dict.get_tensor("add_q_proj.weight");
+    auto add_q_bias = state_dict.get_tensor("add_q_proj.bias");
+    auto add_k_weight = state_dict.get_tensor("add_k_proj.weight");
+    auto add_k_bias = state_dict.get_tensor("add_k_proj.bias");
+    auto add_v_weight = state_dict.get_tensor("add_v_proj.weight");
+    auto add_v_bias = state_dict.get_tensor("add_v_proj.bias");
+
+    if (add_q_weight.defined() && add_k_weight.defined() &&
+        add_v_weight.defined()) {
+      auto fused_add_qkv_weight =
+          torch::cat({add_q_weight, add_k_weight, add_v_weight}, 0)
+              .contiguous();
+      DCHECK_EQ(fused_add_qkv_weight_.sizes(), fused_add_qkv_weight.sizes())
+          << "fused_add_qkv_weight_ size mismatch: expected "
+          << fused_add_qkv_weight_.sizes() << " but got "
+          << fused_add_qkv_weight.sizes();
+      fused_add_qkv_weight_.data().copy_(fused_add_qkv_weight.to(
+          fused_add_qkv_weight_.device(), fused_add_qkv_weight_.dtype()));
+      is_add_qkv_weight_loaded_ = true;
+    }
+
+    if (add_q_bias.defined() && add_k_bias.defined() && add_v_bias.defined()) {
+      auto fused_add_qkv_bias =
+          torch::cat({add_q_bias, add_k_bias, add_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_add_qkv_bias_.sizes(), fused_add_qkv_bias.sizes())
+          << "fused_add_qkv_bias_ size mismatch: expected "
+          << fused_add_qkv_bias_.sizes() << " but got "
+          << fused_add_qkv_bias.sizes();
+      fused_add_qkv_bias_.data().copy_(fused_add_qkv_bias.to(
+          fused_add_qkv_bias_.device(), fused_add_qkv_bias_.dtype()));
+      is_add_qkv_bias_loaded_ = true;
+    }
   }
 
   void verify_loaded_weights(const std::string& prefix) const {
+    CHECK(is_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "qkv_proj.weight";
+    CHECK(is_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "qkv_proj.bias";
+    CHECK(is_add_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "add_qkv_proj.weight";
+    CHECK(is_add_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "add_qkv_proj.bias";
     norm_q_->verify_loaded_weights(prefix + "norm_q.");
     norm_k_->verify_loaded_weights(prefix + "norm_k.");
     norm_added_q_->verify_loaded_weights(prefix + "norm_added_q.");
     norm_added_k_->verify_loaded_weights(prefix + "norm_added_k.");
-    to_q_->verify_loaded_weights(prefix + "to_q.");
-    to_k_->verify_loaded_weights(prefix + "to_k.");
-    to_v_->verify_loaded_weights(prefix + "to_v.");
     to_out_->verify_loaded_weights(prefix + "to_out.0.");
     to_add_out_->verify_loaded_weights(prefix + "to_add_out.");
-    add_q_proj_->verify_loaded_weights(prefix + "add_q_proj.");
-    add_k_proj_->verify_loaded_weights(prefix + "add_k_proj.");
-    add_v_proj_->verify_loaded_weights(prefix + "add_v_proj.");
   }
 
  private:
-  DiTLinear to_q_{nullptr};
-  DiTLinear to_k_{nullptr};
-  DiTLinear to_v_{nullptr};
-  DiTLinear add_q_proj_{nullptr};
-  DiTLinear add_k_proj_{nullptr};
-  DiTLinear add_v_proj_{nullptr};
+  bool is_qkv_weight_loaded_{false};
+  bool is_qkv_bias_loaded_{false};
+  bool is_add_qkv_weight_loaded_{false};
+  bool is_add_qkv_bias_loaded_{false};
+  torch::Tensor fused_qkv_weight_{};
+  torch::Tensor fused_qkv_bias_{};
+  torch::Tensor fused_add_qkv_weight_{};
+  torch::Tensor fused_add_qkv_bias_{};
   DiTLinear to_out_{nullptr};
   DiTLinear to_add_out_{nullptr};
 
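
The four fuse-and-copy blocks in FluxAttentionImpl::load_state_dict (and the two in FluxSingleAttentionImpl) follow one pattern: check that all three source tensors are present, concatenate along dim 0, size-check, and copy into the pre-registered fused parameter. A hedged sketch of a helper that could factor this out (the helper name and free-function form are hypothetical, not part of the patch; it reuses the glog DCHECK_EQ the file already relies on):

#include <torch/torch.h>
#include <glog/logging.h>

// Hypothetical helper: concatenate three per-projection tensors along dim 0
// and copy the result into a pre-registered fused parameter. Returns false if
// any source tensor is missing, mirroring the is_*_loaded_ flags in the patch.
static bool fuse_into(torch::Tensor& fused,
                      const torch::Tensor& q,
                      const torch::Tensor& k,
                      const torch::Tensor& v) {
  if (!q.defined() || !k.defined() || !v.defined()) {
    return false;
  }
  auto cat = torch::cat({q, k, v}, 0).contiguous();
  DCHECK_EQ(fused.sizes(), cat.sizes()) << "fused parameter size mismatch";
  fused.data().copy_(cat.to(fused.device(), fused.dtype()));
  return true;
}

Each loaded flag would then simply record the helper's return value, e.g. is_qkv_weight_loaded_ = fuse_into(fused_qkv_weight_, to_q_weight, to_k_weight, to_v_weight);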
@@ -1119,8 +1196,8 @@ class FluxTransformerBlockImpl : public torch::nn::Module {
   FluxAttention attn_{nullptr};
   torch::nn::LayerNorm norm2_{nullptr};
   FeedForward ff_{nullptr};
-  torch::nn::LayerNorm norm2_context_{nullptr};
   FeedForward ff_context_{nullptr};
+  torch::nn::LayerNorm norm2_context_{nullptr};
   torch::TensorOptions options_;
 };
 TORCH_MODULE(FluxTransformerBlock);