Skip to content

Commit e2c229d

Browse files
committed
Embeddings class improvement
1 parent f028608 commit e2c229d

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

dlib/dnn/layers.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5454,7 +5454,8 @@ namespace dlib
54545454
embeddings_() : num_embeddings(num_embeddings_),
54555455
embedding_dim(embedding_dim_),
54565456
learning_rate_multiplier(1.0f),
5457-
scale_by_freq(true)
5457+
scale_by_freq(true),
5458+
output_scale(std::sqrt(static_cast<float>(embedding_dim_)))
54585459
{
54595460
}
54605461

@@ -5486,12 +5487,17 @@ namespace dlib
54865487
}
54875488
}
54885489

5490+
// Accessor: returns the value currently stored in the output_scale member.
float get_output_scale() const { return output_scale; }
5491+
54895492
// Allocates and randomly initializes the embedding table.
// The subnetwork argument is unused (its name is commented out).
// Steps, in order:
//   1. size embs to num_embeddings x embedding_dim,
//   2. fill it with Gaussian samples from a tt::tensor_rand seeded by std::rand(),
//   3. pass embs through tt::affine_transform with factor 1/sqrt(embedding_dim).
template <typename SUBNET>
54905493
void setup(const SUBNET& /*sub*/)
54915494
{
54925495
embs.set_size(num_embeddings, embedding_dim);
54935496
tt::tensor_rand rnd(std::rand());
54945497
rnd.fill_gaussian(embs);
5498+
5499+
// NOTE(review): presumably the 3-argument affine_transform computes
// dest = A * src, i.e. this rescales embs in place - confirm against the
// tt::affine_transform overloads.
const float init_scale = 1.0f / std::sqrt(static_cast<float>(embedding_dim));
5500+
tt::affine_transform(embs, embs, init_scale);
54955501
}
54965502

54975503
template <typename SUBNET>
@@ -5501,6 +5507,7 @@ namespace dlib
55015507
output.set_size(prev.num_samples(), prev.k(), prev.nr(), embedding_dim);
55025508

55035509
tt::embeddings(output, prev, embs);
5510+
tt::affine_transform(output, output, output_scale);
55045511
}
55055512

55065513
template <typename SUBNET>
@@ -5515,7 +5522,8 @@ namespace dlib
55155522
auto& prev_src = sub.get_output();
55165523

55175524
calc_token_freqs(prev_src, gradient_input);
5518-
tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, learning_rate_multiplier, scale_by_freq);
5525+
const float scaled_lr = learning_rate_multiplier * output_scale;
5526+
tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, scaled_lr, scale_by_freq);
55195527
}
55205528
}
55215529

@@ -5533,6 +5541,7 @@ namespace dlib
55335541
serialize(item.embedding_dim, out);
55345542
serialize(item.learning_rate_multiplier, out);
55355543
serialize(item.scale_by_freq, out);
5544+
serialize(item.output_scale, out);
55365545
}
55375546
friend void deserialize(embeddings_& item, std::istream& in)
55385547
{
@@ -5545,19 +5554,22 @@ namespace dlib
55455554
deserialize(item.embedding_dim, in);
55465555
deserialize(item.learning_rate_multiplier, in);
55475556
deserialize(item.scale_by_freq, in);
5557+
deserialize(item.output_scale, in);
55485558
}
55495559

55505560
// Stream-insertion operator: writes a one-line, human-readable summary of the
// layer to out and returns out so calls can be chained. The summary lists
// num_embeddings, embedding_dim, output_scale (printed under the label scale)
// and learning_rate_multiplier.
friend std::ostream& operator<<(std::ostream& out, const embeddings_& item)
55515561
{
55525562
out << "embeddings (num_embeddings=" << item.num_embeddings
55535563
<< ", embedding_dim=" << item.embedding_dim
5564+
<< ", scale=" << item.output_scale
55545565
<< ") learning_rate_mult=" << item.learning_rate_multiplier;
55555566
return out;
55565567
}
55575568
friend void to_xml(const embeddings_& item, std::ostream& out)
55585569
{
55595570
out << "<embeddings num_embeddings='" << item.num_embeddings
55605571
<< "' embedding_dim='" << item.embedding_dim
5572+
<< "' output_scale='" << item.output_scale
55615573
<< "' learning_rate_mult='"
55625574
<< item.learning_rate_multiplier << "'>\n";
55635575
out << mat(item.embs);
@@ -5589,6 +5601,7 @@ namespace dlib
55895601
unsigned long num_embeddings, embedding_dim;
55905602
double learning_rate_multiplier;
55915603
bool scale_by_freq;
5604+
float output_scale;
55925605
};
55935606

55945607
template <

0 commit comments

Comments
 (0)