@@ -32,29 +32,29 @@ namespace dlib
     namespace canonical_transformer
     {
 
-        template <long seq_len, long d_model, long num_heads, typename SUBNET>
-        using query = reshape_to<num_heads, seq_len, d_model / num_heads,
+        template <long d_model, long num_heads, typename SUBNET>
+        using query = reshape_to<num_heads, -1, d_model / num_heads,
             linear_no_bias<d_model, SUBNET>>;
 
-        template <long seq_len, long d_model, long num_heads, typename SUBNET>
-        using key = reshape_to<num_heads, seq_len, d_model / num_heads,
+        template <long d_model, long num_heads, typename SUBNET>
+        using key = reshape_to<num_heads, -1, d_model / num_heads,
             linear_no_bias<d_model, SUBNET>>;
 
-        template <long seq_len, long d_model, long num_heads, typename SUBNET>
-        using value = reshape_to<num_heads, seq_len, d_model / num_heads,
+        template <long d_model, long num_heads, typename SUBNET>
+        using value = reshape_to<num_heads, -1, d_model / num_heads,
             linear_no_bias<d_model, SUBNET>>;
 
         template <template <typename> class ACT, template <typename> class DO,
-            long seq_len, long d_model, long num_heads, typename SUBNET>
+            long d_model, long num_heads, typename SUBNET>
         using multihead_attention =
-            DO<linear_no_bias<d_model, reshape_to<1, seq_len, d_model,
+            DO<linear_no_bias<d_model, reshape_to<1, -1, d_model,
             multm_prev3<softmaxm<tril_mask<
             scale_weights<d_model / num_heads,
             multm_prev4<
-            rope<query<seq_len, d_model, num_heads, skip1<
+            rope<query<d_model, num_heads, skip1<
             tag4<transpose<
-            rope<key<seq_len, d_model, num_heads, skip2<
-            tag3<value<seq_len, d_model, num_heads,
+            rope<key<d_model, num_heads, skip2<
+            tag3<value<d_model, num_heads,
             tag2<SUBNET>>>>>>>>>>>>>>>>>>>;
 
         template <template <typename> class ACT, template <typename> class DO,
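The hunk above swaps the fixed seq_len dimension in the query/key/value reshapes for -1, so reshape_to infers the row (sequence) dimension from the incoming tensor instead of baking it into the network type. A minimal sketch of how the reworked aliases could be instantiated; the toy_* names and the 512-wide, 8-head configuration are illustrative assumptions, not part of the patch:

    // Hypothetical instantiations for d_model = 512, num_heads = 8.
    // reshape_to<8, -1, 64, ...> splits the 512 channels into 8 heads of 64
    // and leaves the sequence length to be read from the input at run time,
    // which is why seq_len could be dropped from the template parameter lists.
    template <typename SUBNET> using toy_query = canonical_transformer::query<512, 8, SUBNET>;
    template <typename SUBNET> using toy_key   = canonical_transformer::key<512, 8, SUBNET>;
    template <typename SUBNET> using toy_value = canonical_transformer::value<512, 8, SUBNET>;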
@@ -68,29 +68,29 @@ namespace dlib
             tag7<silu<linear<(d_model * 2) / 7, tag6<SUBNET>>>>>>>>>;
 
         template <template <typename> class ACT, template <typename> class DO,
-            long seq_len, long d_model, long num_heads, typename SUBNET>
+            long d_model, long num_heads, typename SUBNET>
         using transformer_block =
             add_prev5<std_ffn<ACT, DO, d_model, rms_norm<tag5<
-            add_prev1<multihead_attention<ACT, DO, seq_len, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>>>>;
+            add_prev1<multihead_attention<ACT, DO, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>>>>;
 
         template <long remaining_layers, template <typename> class ACT, template <typename> class DO,
-            long seq_len, long d_model, long num_heads, typename SUBNET, typename enabled = void>
+            long d_model, long num_heads, typename SUBNET, typename enabled = void>
         struct transformer_stack_impl
         {
-            using type = transformer_block<ACT, DO, seq_len, d_model, num_heads,
-                typename transformer_stack_impl<remaining_layers - 1, ACT, DO, seq_len, d_model, num_heads, SUBNET>::type>;
+            using type = transformer_block<ACT, DO, d_model, num_heads,
+                typename transformer_stack_impl<remaining_layers - 1, ACT, DO, d_model, num_heads, SUBNET>::type>;
         };
 
         template <template <typename> class ACT, template <typename> class DO,
-            long seq_len, long d_model, long num_heads, typename SUBNET>
-        struct transformer_stack_impl<0, ACT, DO, seq_len, d_model, num_heads, SUBNET, void>
+            long d_model, long num_heads, typename SUBNET>
+        struct transformer_stack_impl<0, ACT, DO, d_model, num_heads, SUBNET, void>
        {
            using type = tag10<SUBNET>;
        };
 
        template <long num_layers, template <typename> class ACT, template <typename> class DO,
-           long seq_len, long d_model, long num_heads, typename SUBNET>
-       using transformer_stack = typename transformer_stack_impl<num_layers, ACT, DO, seq_len, d_model, num_heads, SUBNET>::type;
+           long d_model, long num_heads, typename SUBNET>
+       using transformer_stack = typename transformer_stack_impl<num_layers, ACT, DO, d_model, num_heads, SUBNET>::type;
 
    } // namespace std_transformer
 
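The seq_len removal propagates through transformer_block and transformer_stack, leaving depth, activation, dropout, d_model, and num_heads as the only compile-time choices. A sketch of a possible instantiation under that assumption; the 6-layer / 512 / 8 values and the toy_stack name are hypothetical, not part of the patch:

    // Hypothetical: a 6-layer, 512-wide, 8-head stack. The same network type
    // now accepts inputs of any sequence length.
    template <typename SUBNET>
    using toy_stack = canonical_transformer::transformer_stack<6, gelu, dropout, 512, 8, SUBNET>;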
@@ -179,7 +179,6 @@ namespace dlib
         using l_net_type = L_NET;
 
         explicit hrm_() :
-            seq_len(0),
             hidden_dim(0),
             learning_rate_multiplier(1.0)
         {
@@ -190,7 +189,6 @@ namespace dlib
             l_net(other.l_net),
             z_h_init(other.z_h_init),
             z_l_init(other.z_l_init),
-            seq_len(other.seq_len),
             hidden_dim(other.hidden_dim),
             learning_rate_multiplier(other.learning_rate_multiplier)
         {
@@ -203,7 +201,6 @@ namespace dlib
             l_net = other.l_net;
             z_h_init = other.z_h_init;
             z_l_init = other.z_l_init;
-            seq_len = other.seq_len;
             hidden_dim = other.hidden_dim;
             learning_rate_multiplier = other.learning_rate_multiplier;
         }
@@ -215,8 +212,7 @@ namespace dlib
         {
             const tensor& input = sub.get_output();
 
-            // Store dimensions for initialization
-            seq_len = input.nr();
+            // Store dimension for initialization
             hidden_dim = input.nc();
 
             // Initialize hidden states with truncated normal (std=1, trunc=2)
@@ -229,6 +225,7 @@ namespace dlib
             const tensor& x = sub.get_output();
             const long batch_size = x.num_samples();
             const long k = x.k();
+            const long seq_len = x.nr();
 
             // Allocate working tensors with proper batch size
             z_h_current.copy_size(x);
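Together with the setup() hunk above, this moves seq_len from a member cached at setup time to a per-call local: every forward pass re-reads the row count of its current input. A condensed sketch of that pattern, using only the tensor accessors that appear in the hunks; the function name and standalone framing are hypothetical:

    #include <dlib/dnn.h>

    // Hypothetical condensation of the new pattern: shape is queried per call.
    void shape_probe(const dlib::tensor& x, dlib::resizable_tensor& z_h_current)
    {
        const long batch_size = x.num_samples();  // sequences in this batch
        const long seq_len    = x.nr();           // tokens per sequence, read from this input
        z_h_current.copy_size(x);                 // working state sized to match this input
        (void)batch_size;                         // silence unused-variable warnings in the sketch
        (void)seq_len;
    }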
@@ -356,7 +353,6 @@ namespace dlib
             serialize(item.l_net, out);
             serialize(item.z_h_init, out);
             serialize(item.z_l_init, out);
-            serialize(item.seq_len, out);
             serialize(item.hidden_dim, out);
             serialize(item.learning_rate_multiplier, out);
         }
@@ -372,7 +368,6 @@ namespace dlib
             deserialize(item.l_net, in);
             deserialize(item.z_h_init, in);
             deserialize(item.z_l_init, in);
-            deserialize(item.seq_len, in);
             deserialize(item.hidden_dim, in);
             deserialize(item.learning_rate_multiplier, in);
         }
@@ -449,7 +444,6 @@ namespace dlib
         resizable_tensor z_l_init;
 
         // Dimensions and learning rate
-        long seq_len;
         long hidden_dim;
         double learning_rate_multiplier;
 
@@ -473,7 +467,7 @@ namespace dlib
 
    // Gate network: produces raw logits for expert selection
    template <long num_experts, template <typename> class DO, typename SUBNET>
-   using gate = fc<num_experts, DO<leaky_relu<fc<num_experts * 8, SUBNET>>>>;
+   using gate = fc<num_experts, DO<leaky_relu<fc<num_experts * 8, avg_pool_everything<SUBNET>>>>>;
 
    struct training_mode_tag {};
    struct inference_mode_tag {};
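The gate change inserts avg_pool_everything ahead of the fully connected layers, so each sample is averaged over its rows and columns down to one value per channel before the expert logits are computed, making the gate independent of the input's spatial extent. A sketch of a possible instantiation; the 8-expert count, the dropout choice, and the toy_gate name are illustrative only:

    // Hypothetical: an 8-expert gate with dropout as the DO layer.
    template <typename SUBNET>
    using toy_gate = gate<8, dropout, SUBNET>;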