@@ -180,7 +180,8 @@ namespace dlib
 
         explicit hrm_() :
             seq_len(0),
-            hidden_dim(0)
+            hidden_dim(0),
+            learning_rate_multiplier(1.0)
         {
         }
 
@@ -190,7 +191,8 @@ namespace dlib
             z_h_init(other.z_h_init),
             z_l_init(other.z_l_init),
             seq_len(other.seq_len),
-            hidden_dim(other.hidden_dim)
+            hidden_dim(other.hidden_dim),
+            learning_rate_multiplier(other.learning_rate_multiplier)
         {
         }
 
@@ -203,6 +205,7 @@ namespace dlib
                 z_l_init = other.z_l_init;
                 seq_len = other.seq_len;
                 hidden_dim = other.hidden_dim;
+                learning_rate_multiplier = other.learning_rate_multiplier;
             }
             return *this;
         }
@@ -321,6 +324,15 @@ namespace dlib
             tt::add(1.0f, prev_grad, 1.0f, grad_l);
         }
 
+        void set_learning_rate_multiplier(double val)
+        {
+            learning_rate_multiplier = val;
+            set_all_learning_rate_multipliers(h_net, val);
+            set_all_learning_rate_multipliers(l_net, val);
+        }
+        double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
+
+
         // Cleans up the internal state of H and L networks
         void clean()
         {
@@ -346,6 +358,7 @@ namespace dlib
             serialize(item.z_l_init, out);
             serialize(item.seq_len, out);
             serialize(item.hidden_dim, out);
+            serialize(item.learning_rate_multiplier, out);
         }
 
         friend void deserialize(hrm_& item, std::istream& in)
@@ -361,17 +374,26 @@ namespace dlib
             deserialize(item.z_l_init, in);
             deserialize(item.seq_len, in);
             deserialize(item.hidden_dim, in);
+            deserialize(item.learning_rate_multiplier, in);
         }
 
         friend std::ostream& operator<<(std::ostream& out, const hrm_& item)
         {
-            out << "hrm (N=" << N << ", T=" << T << ")";
+            out << "hrm\t ("
+                << "N=" << N
+                << ", T=" << T
+                << ")";
+            out << " learning_rate_mult=" << item.learning_rate_multiplier;
             return out;
         }
 
         friend void to_xml(const hrm_& item, std::ostream& out)
         {
-            out << "<hrm N='" << N << "' T='" << T << "'>\n";
+            out << "<hrm"
+                << " N='" << N << "'"
+                << " T='" << T << "'"
+                << " learning_rate_mult='" << item.learning_rate_multiplier << "'"
+                << ">\n";
             out << "<h_module>\n";
             to_xml(item.h_net, out);
             out << "</h_module>\n";
@@ -426,9 +448,10 @@ namespace dlib
         resizable_tensor z_h_init;
         resizable_tensor z_l_init;
 
-        // Dimensions
+        // Dimensions and learning rate
         long seq_len;
         long hidden_dim;
+        double learning_rate_multiplier;
 
         // Temporary computation tensors
         resizable_tensor z_h_current;
@@ -846,6 +869,7 @@ namespace dlib
             for (auto& expert : experts)
                 set_all_learning_rate_multipliers(expert, val);
         }
+        double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
 
         // Direct access to expert networks (for inspection/debugging)
         EXPERT_NET& get_expert(size_t idx) {
@@ -900,24 +924,42 @@ namespace dlib
         friend std::ostream& operator<<(std::ostream& out, const moe_& item)
         {
             const bool is_training = std::is_same<MODE, training_mode_tag>::value;
-            out << "moe"
-                << " (experts=" << item.n_experts
+            out << "moe\t ("
+                << "experts=" << item.n_experts
                 << ", top_k=" << item.top_k
                 << ", mode=" << (is_training ? "train" : "infer")
-                << ", noise=" << item.noise_scale << ")"
-                << ", lb=" << item.load_balance_weight << ")";
+                << ", noise=" << item.noise_scale
+                << ", lb=" << item.load_balance_weight
+                << ")";
+            out << " learning_rate_mult=" << item.learning_rate_multiplier;
             return out;
         }
 
         friend void to_xml(const moe_& item, std::ostream& out)
         {
             const bool is_training = std::is_same<MODE, training_mode_tag>::value;
-            out << "<moe>\n";
-            out << "<num_experts>" << item.n_experts << "</num_experts>\n";
-            out << "<top_k>" << item.top_k << "</top_k>\n";
-            out << "<noise_scale>" << item.noise_scale << "</noise_scale>\n";
-            out << "<load_balance_weight>" << item.load_balance_weight << "</load_balance_weight>\n";
-            out << "<mode>" << (is_training ? "training" : "inference") << "</mode>\n";
+            out << "<moe"
+                << " num_experts='" << item.n_experts << "'"
+                << " top_k='" << item.top_k << "'"
+                << " noise_scale='" << item.noise_scale << "'"
+                << " usage_update_rate='" << item.usage_update_rate << "'"
+                << " load_balance_weight='" << item.load_balance_weight << "'"
+                << " learning_rate_mult='" << item.learning_rate_multiplier << "'"
+                << " mode='" << (is_training ? "training" : "inference") << "'"
+                << ">\n";
+            for (size_t i = 0; i < item.experts.size(); ++i)
+            {
+                out << "<expert index='" << i << "'>\n";
+                to_xml(item.experts[i], out);
+                out << "</expert>\n";
+            }
+            out << "<expert_usage>";
+            for (size_t i = 0; i < item.expert_usage.size(); ++i)
+            {
+                if (i > 0) out << " ";
+                out << item.expert_usage[i];
+            }
+            out << "</expert_usage>\n";
             out << "</moe>\n";
         }
 
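For reference, a minimal usage sketch of the setter added above. Only set_learning_rate_multiplier and get_learning_rate_multiplier come from this change; the helper name, the template parameters, and the layer index used in the comment are illustrative assumptions, not part of the diff.

#include <dlib/dnn.h>

// Sketch: given a dlib network whose K-th layer holds the hrm_ (or moe_) layer,
// scale that layer's effective learning rate relative to the rest of the network.
// K and NET are placeholders supplied by the caller.
template <size_t K, typename NET>
void scale_layer_learning_rate(NET& net, double mult)
{
    // The setter also forwards the multiplier to the layer's internal H/L
    // sub-networks (hrm_) or expert networks (moe_), as added in this diff.
    dlib::layer<K>(net).layer_details().set_learning_rate_multiplier(mult);
    DLIB_CASSERT(dlib::layer<K>(net).layer_details().get_learning_rate_multiplier() == mult);
}

For example, scale_layer_learning_rate<2>(net, 0.1) would make that block learn at one tenth of the network's learning rate, and a multiplier of 0 would freeze it during fine-tuning.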