Skip to content

Commit 7cce339

Browse files
committed
Update
1 parent 78aae5b commit 7cce339

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

dlib/dnn/transformer.h

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ namespace dlib
180180

181181
explicit hrm_() :
182182
seq_len(0),
183-
hidden_dim(0)
183+
hidden_dim(0),
184+
learning_rate_multiplier(1.0)
184185
{
185186
}
186187

@@ -190,7 +191,8 @@ namespace dlib
190191
z_h_init(other.z_h_init),
191192
z_l_init(other.z_l_init),
192193
seq_len(other.seq_len),
193-
hidden_dim(other.hidden_dim)
194+
hidden_dim(other.hidden_dim),
195+
learning_rate_multiplier(other.learning_rate_multiplier)
194196
{
195197
}
196198

@@ -203,6 +205,7 @@ namespace dlib
203205
z_l_init = other.z_l_init;
204206
seq_len = other.seq_len;
205207
hidden_dim = other.hidden_dim;
208+
learning_rate_multiplier = other.learning_rate_multiplier;
206209
}
207210
return *this;
208211
}
@@ -321,6 +324,15 @@ namespace dlib
321324
tt::add(1.0f, prev_grad, 1.0f, grad_l);
322325
}
323326

327+
void set_learning_rate_multiplier(double val)
328+
{
329+
learning_rate_multiplier = val;
330+
set_all_learning_rate_multipliers(h_net, val);
331+
set_all_learning_rate_multipliers(l_net, val);
332+
}
333+
double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
334+
335+
324336
// Cleans up the internal state of H and L networks
325337
void clean()
326338
{
@@ -346,6 +358,7 @@ namespace dlib
346358
serialize(item.z_l_init, out);
347359
serialize(item.seq_len, out);
348360
serialize(item.hidden_dim, out);
361+
serialize(item.learning_rate_multiplier, out);
349362
}
350363

351364
friend void deserialize(hrm_& item, std::istream& in)
@@ -361,17 +374,26 @@ namespace dlib
361374
deserialize(item.z_l_init, in);
362375
deserialize(item.seq_len, in);
363376
deserialize(item.hidden_dim, in);
377+
deserialize(item.learning_rate_multiplier, in);
364378
}
365379

366380
friend std::ostream& operator<<(std::ostream& out, const hrm_& item)
367381
{
368-
out << "hrm (N=" << N << ", T=" << T << ")";
382+
out << "hrm\t ("
383+
<< "N=" << N
384+
<< ", T=" << T
385+
<< ")";
386+
out << " learning_rate_mult=" << item.learning_rate_multiplier;
369387
return out;
370388
}
371389

372390
friend void to_xml(const hrm_& item, std::ostream& out)
373391
{
374-
out << "<hrm N='" << N << "' T='" << T << "'>\n";
392+
out << "<hrm"
393+
<< " N='" << N << "'"
394+
<< " T='" << T << "'"
395+
<< " learning_rate_mult='" << item.learning_rate_multiplier << "'"
396+
<< ">\n";
375397
out << " <h_module>\n";
376398
to_xml(item.h_net, out);
377399
out << " </h_module>\n";
@@ -426,9 +448,10 @@ namespace dlib
426448
resizable_tensor z_h_init;
427449
resizable_tensor z_l_init;
428450

429-
// Dimensions
451+
// Dimensions and learning rate
430452
long seq_len;
431453
long hidden_dim;
454+
double learning_rate_multiplier;
432455

433456
// Temporary computation tensors
434457
resizable_tensor z_h_current;
@@ -846,6 +869,7 @@ namespace dlib
846869
for (auto& expert : experts)
847870
set_all_learning_rate_multipliers(expert, val);
848871
}
872+
double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
849873

850874
// Direct access to expert networks (for inspection/debugging)
851875
EXPERT_NET& get_expert(size_t idx) {
@@ -900,24 +924,42 @@ namespace dlib
900924
friend std::ostream& operator<<(std::ostream& out, const moe_& item)
901925
{
902926
const bool is_training = std::is_same<MODE, training_mode_tag>::value;
903-
out << "moe"
904-
<< " (experts=" << item.n_experts
927+
out << "moe\t ("
928+
<< "experts=" << item.n_experts
905929
<< ", top_k=" << item.top_k
906930
<< ", mode=" << (is_training ? "train" : "infer")
907-
<< ", noise=" << item.noise_scale << ")"
908-
<< ", lb=" << item.load_balance_weight << ")";
931+
<< ", noise=" << item.noise_scale
932+
<< ", lb=" << item.load_balance_weight
933+
<< ")";
934+
out << " learning_rate_mult=" << item.learning_rate_multiplier;
909935
return out;
910936
}
911937

912938
friend void to_xml(const moe_& item, std::ostream& out)
913939
{
914940
const bool is_training = std::is_same<MODE, training_mode_tag>::value;
915-
out << "<moe>\n";
916-
out << " <num_experts>" << item.n_experts << "</num_experts>\n";
917-
out << " <top_k>" << item.top_k << "</top_k>\n";
918-
out << " <noise_scale>" << item.noise_scale << "</noise_scale>\n";
919-
out << " <load_balance_weight>" << item.load_balance_weight << "</load_balance_weight>\n";
920-
out << " <mode>" << (is_training ? "training" : "inference") << "</mode>\n";
941+
out << "<moe"
942+
<< " num_experts='" << item.n_experts << "'"
943+
<< " top_k='" << item.top_k << "'"
944+
<< " noise_scale='" << item.noise_scale << "'"
945+
<< " usage_update_rate='" << item.usage_update_rate << "'"
946+
<< " load_balance_weight='" << item.load_balance_weight << "'"
947+
<< " learning_rate_mult='" << item.learning_rate_multiplier << "'"
948+
<< " mode='" << (is_training ? "training" : "inference") << "'"
949+
<< ">\n";
950+
for (size_t i = 0; i < item.experts.size(); ++i)
951+
{
952+
out << "<expert index='" << i << "'>\n";
953+
to_xml(item.experts[i], out);
954+
out << "</expert>\n";
955+
}
956+
out << "<expert_usage>";
957+
for (size_t i = 0; i < item.expert_usage.size(); ++i)
958+
{
959+
if (i > 0) out << " ";
960+
out << item.expert_usage[i];
961+
}
962+
out << "</expert_usage>\n";
921963
out << "</moe>\n";
922964
}
923965

examples/slm_mixture_of_experts_ex.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ int main(int argc, char** argv)
610610
trainer.set_min_learning_rate(1e-6);
611611
trainer.set_mini_batch_size(batch_size);
612612
trainer.set_iterations_without_progress_threshold(patience);
613-
trainer.set_synchronization_file("chkpt-" + model_file, std::chrono::minutes(10));
613+
trainer.set_synchronization_file("chkpt-" + model_file, std::chrono::minutes(5));
614614
trainer.be_quiet();
615615
cout << net << endl << endl; // Show the model architecture
616616
cout << "Starting training...\n";

0 commit comments

Comments
 (0)