Skip to content

Commit 78aae5b

Browse files
committed
Update
1 parent 5a83f2c commit 78aae5b

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

dlib/dnn/transformer.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ namespace dlib
484484
top_k(top_e),
485485
usage_update_rate(0.05f),
486486
load_balance_weight(0.01f),
487+
learning_rate_multiplier(1.0),
487488
cached_batch_size_(0)
488489
{
489490
}
@@ -494,6 +495,7 @@ namespace dlib
494495
top_k(other.top_k),
495496
usage_update_rate(other.usage_update_rate),
496497
load_balance_weight(other.load_balance_weight),
498+
learning_rate_multiplier(other.learning_rate_multiplier),
497499
expert_usage(other.expert_usage),
498500
cached_batch_size_(0)
499501
{
@@ -511,6 +513,7 @@ namespace dlib
511513
top_k = other.top_k;
512514
usage_update_rate = other.usage_update_rate;
513515
load_balance_weight = other.load_balance_weight;
516+
learning_rate_multiplier = other.learning_rate_multiplier;
514517
expert_usage = other.expert_usage;
515518
cached_batch_size_ = 0;
516519

@@ -793,7 +796,8 @@ namespace dlib
793796
}
794797
}
795798

796-
if (std::is_same<MODE, training_mode_tag>::value && load_balance_weight > 0) {
799+
if (std::is_same<MODE, training_mode_tag>::value && load_balance_weight > 0
800+
&& learning_rate_multiplier > 0) {
797801
tensor& gate_grad = layer<TAG>(sub).get_gradient_input();
798802
float* gate_grad_data = gate_grad.host();
799803

@@ -838,6 +842,7 @@ namespace dlib
838842

839843
void set_learning_rate_multiplier(double val)
840844
{
845+
learning_rate_multiplier = val;
841846
for (auto& expert : experts)
842847
set_all_learning_rate_multipliers(expert, val);
843848
}
@@ -868,6 +873,7 @@ namespace dlib
868873
serialize(item.noise_scale, out);
869874
serialize(item.usage_update_rate, out);
870875
serialize(item.load_balance_weight, out);
876+
serialize(item.learning_rate_multiplier, out);
871877
serialize(item.experts, out);
872878
serialize(item.expert_usage, out);
873879
}
@@ -884,6 +890,7 @@ namespace dlib
884890
deserialize(item.noise_scale, in);
885891
deserialize(item.usage_update_rate, in);
886892
deserialize(item.load_balance_weight, in);
893+
deserialize(item.learning_rate_multiplier, in);
887894
deserialize(item.experts, in);
888895
deserialize(item.expert_usage, in);
889896

@@ -928,11 +935,12 @@ namespace dlib
928935
}
929936

930937
// Configuration
931-
long n_experts; // Number of expert networks
932-
float noise_scale; // Gaussian noise std for exploration
933-
long top_k; // Number of experts to activate per sample
934-
float usage_update_rate; // EMA smoothing rate for usage tracking
935-
float load_balance_weight; // Auxiliary loss coefficient for expert load balancing
938+
long n_experts; // Number of expert networks
939+
float noise_scale; // Gaussian noise std for exploration
940+
long top_k; // Number of experts to activate per sample
941+
float usage_update_rate; // EMA smoothing rate for usage tracking
942+
float load_balance_weight; // Auxiliary loss coefficient for expert load balancing
943+
double learning_rate_multiplier;
936944

937945
// Expert networks
938946
std::vector<EXPERT_NET> experts;

0 commit comments

Comments
 (0)