@@ -5454,7 +5454,8 @@ namespace dlib
// Default constructor of embeddings_, shown here as diff records: the '-'
// record is the initializer being replaced, the '+' records are what the
// change puts in its place.
// It copies num_embeddings_ / embedding_dim_ into the corresponding members
// (presumably the class template's parameters -- their declaration is outside
// this hunk, confirm), sets learning_rate_multiplier to 1 and scale_by_freq
// to true.
54545454 embeddings_ () : num_embeddings(num_embeddings_),
54555455 embedding_dim (embedding_dim_),
54565456 learning_rate_multiplier(1 .0f ),
5457- scale_by_freq(true )
5457+ scale_by_freq(true ),
// Added by this change: output_scale is fixed at construction to
// sqrt(embedding_dim_), computed in float. It is listed last, matching the
// position of the `float output_scale;` member added after scale_by_freq.
5458+ output_scale(std::sqrt(static_cast <float >(embedding_dim_)))
54585459 {
54595460 }
54605461
@@ -5486,12 +5487,17 @@ namespace dlib
54865487 }
54875488 }
54885489
// Accessor added by this change: returns the stored output_scale member by
// value. Elsewhere in this diff the same member is the factor passed to
// tt::affine_transform right after the tt::embeddings lookup.
5490+ float get_output_scale () const { return output_scale; }
5491+
// Allocates and randomly initialises the embedding table `embs`.
// The subnetwork argument is accepted but not used (its name is commented out).
54895492 template <typename SUBNET>
54905493 void setup (const SUBNET& /* sub*/ )
54915494 {
// One row per embedding, embedding_dim columns.
54925495 embs.set_size (num_embeddings, embedding_dim);
// The generator is seeded from std::rand(), so the initial values depend on
// the global C random state at the time setup() runs.
54935496 tt::tensor_rand rnd (std::rand ());
// Called with no explicit mean/stddev; presumably a standard normal by
// default -- TODO confirm against tt::tensor_rand.
54945497 rnd.fill_gaussian (embs);
5498+
// Added by this change: rescale the freshly drawn values in place by
// 1/sqrt(embedding_dim). This assumes the three-argument
// tt::affine_transform(dest, src, A) computes dest = A*src -- TODO confirm.
// NOTE(review): this factor is the reciprocal of the sqrt(embedding_dim_)
// value the constructor stores in output_scale, so the two cancel on the
// initial weights when output_scale is applied to the lookup result.
5499+ const float init_scale = 1 .0f / std::sqrt (static_cast <float >(embedding_dim));
5500+ tt::affine_transform (embs, embs, init_scale);
54955501 }
54965502
54975503 template <typename SUBNET>
@@ -5501,6 +5507,7 @@ namespace dlib
55015507 output.set_size (prev.num_samples (), prev.k (), prev.nr (), embedding_dim);
55025508
55035509 tt::embeddings (output, prev, embs);
5510+ tt::affine_transform (output, output, output_scale);
55045511 }
55055512
55065513 template <typename SUBNET>
@@ -5515,7 +5522,8 @@ namespace dlib
55155522 auto & prev_src = sub.get_output ();
55165523
55175524 calc_token_freqs (prev_src, gradient_input);
5518- tt::embeddings_gradient (prev_src, gradient_input, embs, freqs, learning_rate_multiplier, scale_by_freq);
5525+ const float scaled_lr = learning_rate_multiplier * output_scale;
5526+ tt::embeddings_gradient (prev_src, gradient_input, embs, freqs, scaled_lr, scale_by_freq);
55195527 }
55205528 }
55215529
@@ -5533,6 +5541,7 @@ namespace dlib
55335541 serialize (item.embedding_dim , out);
55345542 serialize (item.learning_rate_multiplier , out);
55355543 serialize (item.scale_by_freq , out);
5544+ serialize (item.output_scale , out);
55365545 }
55375546 friend void deserialize (embeddings_& item, std::istream& in)
55385547 {
@@ -5545,19 +5554,22 @@ namespace dlib
55455554 deserialize (item.embedding_dim , in);
55465555 deserialize (item.learning_rate_multiplier , in);
55475556 deserialize (item.scale_by_freq , in);
5557+ deserialize (item.output_scale , in);
55485558 }
55495559
// Streams a one-line, human-readable summary of the layer to `out` and returns
// the same stream so calls can be chained. The fields are written in the
// order: num_embeddings, embedding_dim, scale, learning_rate_mult.
55505560 friend std::ostream& operator <<(std::ostream& out, const embeddings_& item)
55515561 {
55525562 out << " embeddings (num_embeddings=" << item.num_embeddings
55535563 << " , embedding_dim=" << item.embedding_dim
// Added by this change: the output_scale member is printed under the shorter
// label "scale" (the XML writer in this diff uses "output_scale" instead).
5564+ << " , scale=" << item.output_scale
55545565 << " ) learning_rate_mult=" << item.learning_rate_multiplier ;
55555566 return out;
55565567 }
55575568 friend void to_xml (const embeddings_& item, std::ostream& out)
55585569 {
55595570 out << " <embeddings num_embeddings='" << item.num_embeddings
55605571 << " ' embedding_dim='" << item.embedding_dim
5572+ << " ' output_scale='" << item.output_scale
55615573 << " ' learning_rate_mult='"
55625574 << item.learning_rate_multiplier << " '>\n " ;
55635575 out << mat (item.embs );
@@ -5589,6 +5601,7 @@ namespace dlib
55895601 unsigned long num_embeddings, embedding_dim;
55905602 double learning_rate_multiplier;
55915603 bool scale_by_freq;
5604+ float output_scale;
55925605 };
55935606
55945607 template <
0 commit comments