@@ -484,6 +484,7 @@ namespace dlib
484484 top_k(top_e),
485485 usage_update_rate(0 .05f ),
486486 load_balance_weight(0 .01f ),
487+ learning_rate_multiplier(1.0 ),
487488 cached_batch_size_(0 )
488489 {
489490 }
@@ -494,6 +495,7 @@ namespace dlib
494495 top_k (other.top_k),
495496 usage_update_rate (other.usage_update_rate),
496497 load_balance_weight (other.load_balance_weight),
498+ learning_rate_multiplier (other.learning_rate_multiplier),
497499 expert_usage (other.expert_usage),
498500 cached_batch_size_ (0 )
499501 {
@@ -511,6 +513,7 @@ namespace dlib
511513 top_k = other.top_k ;
512514 usage_update_rate = other.usage_update_rate ;
513515 load_balance_weight = other.load_balance_weight ;
516+ learning_rate_multiplier = other.learning_rate_multiplier ;
514517 expert_usage = other.expert_usage ;
515518 cached_batch_size_ = 0 ;
516519
@@ -793,7 +796,8 @@ namespace dlib
793796 }
794797 }
795798
796- if (std::is_same<MODE, training_mode_tag>::value && load_balance_weight > 0 ) {
799+ if (std::is_same<MODE, training_mode_tag>::value && load_balance_weight > 0
800+ && learning_rate_multiplier > 0 ) {
797801 tensor& gate_grad = layer<TAG>(sub).get_gradient_input ();
798802 float * gate_grad_data = gate_grad.host ();
799803
@@ -838,6 +842,7 @@ namespace dlib
838842
839843 void set_learning_rate_multiplier (double val)
840844 {
845+ learning_rate_multiplier = val;
841846 for (auto & expert : experts)
842847 set_all_learning_rate_multipliers (expert, val);
843848 }
@@ -868,6 +873,7 @@ namespace dlib
868873 serialize (item.noise_scale , out);
869874 serialize (item.usage_update_rate , out);
870875 serialize (item.load_balance_weight , out);
876+ serialize (item.learning_rate_multiplier , out);
871877 serialize (item.experts , out);
872878 serialize (item.expert_usage , out);
873879 }
@@ -884,6 +890,7 @@ namespace dlib
884890 deserialize (item.noise_scale , in);
885891 deserialize (item.usage_update_rate , in);
886892 deserialize (item.load_balance_weight , in);
893+ deserialize (item.learning_rate_multiplier , in);
887894 deserialize (item.experts , in);
888895 deserialize (item.expert_usage , in);
889896
@@ -928,11 +935,12 @@ namespace dlib
928935 }
929936
930937 // Configuration
931- long n_experts; // Number of expert networks
932- float noise_scale; // Gaussian noise std for exploration
933- long top_k; // Number of experts to activate per sample
934- float usage_update_rate; // EMA smoothing rate for usage tracking
935- float load_balance_weight; // Auxiliary loss coefficient for expert load balancing
938+ long n_experts; // Number of expert networks
939+ float noise_scale; // Gaussian noise std for exploration
940+ long top_k; // Number of experts to activate per sample
941+ float usage_update_rate; // EMA smoothing rate for usage tracking
942+ float load_balance_weight; // Auxiliary loss coefficient for expert load balancing
943+ double learning_rate_multiplier;
936944
937945 // Expert networks
938946 std::vector<EXPERT_NET> experts;
0 commit comments