Skip to content

Commit 0ca05c7

Browse files
authored
Merge branch 'master' into softmaxm-layer
2 parents ee07b5f + 4e53f83 commit 0ca05c7

File tree

4 files changed

+347
-4
lines changed

4 files changed

+347
-4
lines changed

dlib/dnn/layers.h

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4803,8 +4803,132 @@ namespace dlib
48034803

48044804
// ----------------------------------------------------------------------------------------
48054805

4806-
}
4806+
// Tag types used to select the value that tril_ writes above the diagonal.
struct neg_infinity_tag {};
struct zero_tag {};

// Trait that is true for the special tag types above and false for anything
// else (e.g. void, which selects the num_/den_ numeric diagonal value).
template <typename T>
struct is_special_value : std::false_type {};
template <>
struct is_special_value<neg_infinity_tag> : std::true_type {};
template <>
struct is_special_value<zero_tag> : std::true_type {};
4815+
4816+
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
4817+
class tril_
4818+
{
4819+
public:
4820+
tril_(): diag(diag_), diag_value(compute_diag_value()) {}
4821+
4822+
template <typename SUBNET>
4823+
void setup(const SUBNET& /*sub*/)
4824+
{
4825+
}
4826+
4827+
template <typename SUBNET>
4828+
void forward(const SUBNET& sub, resizable_tensor& output)
4829+
{
4830+
auto& prev = sub.get_output();
4831+
output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
4832+
4833+
check_mask(prev);
4834+
tt::multiply(false, output, prev, binary_mask);
4835+
if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
4836+
}
4837+
template <typename SUBNET>
4838+
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
4839+
{
4840+
auto& prev_grad = sub.get_gradient_input();
4841+
tt::multiply(true, prev_grad, gradient_input, binary_mask);
4842+
}
4843+
4844+
inline dpoint map_input_to_output(const dpoint& p) const { return p; }
4845+
inline dpoint map_output_to_input(const dpoint& p) const { return p; }
4846+
4847+
const tensor& get_layer_params() const { return params; }
4848+
tensor& get_layer_params() { return params; }
4849+
4850+
friend void serialize(const tril_& item, std::ostream& out)
4851+
{
4852+
serialize("tril_", out);
4853+
serialize(item.diag, out);
4854+
serialize(item.diag_value, out);
4855+
}
4856+
friend void deserialize(tril_& item, std::istream& in)
4857+
{
4858+
std::string version;
4859+
deserialize(version, in);
4860+
if (version != "tril_")
4861+
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
4862+
deserialize(item.diag, in);
4863+
deserialize(item.diag_value, in);
4864+
}
4865+
4866+
friend std::ostream& operator<<(std::ostream& out, const tril_& item)
4867+
{
4868+
out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
4869+
return out;
4870+
}
4871+
friend void to_xml(const tril_& item, std::ostream& out)
4872+
{
4873+
out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
4874+
}
48074875

4808-
#endif // DLIB_DNn_LAYERS_H_
4876+
private:
4877+
float compute_diag_value() const {
4878+
if (std::is_same<tag_, neg_infinity_tag>::value)
4879+
return -std::numeric_limits<float>::infinity();
4880+
else if (std::is_same<tag_, zero_tag>::value)
4881+
return 0.0f;
4882+
else
4883+
return static_cast<float>(num_) / static_cast<float>(den_);
4884+
}
48094885

4886+
void check_mask(const tensor& t)
4887+
{
4888+
if (!have_same_dimensions(binary_mask, t)) {
4889+
binary_mask.copy_size(t);
4890+
binary_mask = 1;
4891+
if (diag_value != 0.0f) {
4892+
output_mask.copy_size(t);
4893+
output_mask = 0;
4894+
}
4895+
for (long s = 0; s < output_mask.num_samples(); ++s)
4896+
{
4897+
for (long k = 0; k < output_mask.k(); ++k)
4898+
{
4899+
for (long r = 0; r < output_mask.nr(); ++r)
4900+
{
4901+
for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
4902+
{
4903+
if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
4904+
binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
4905+
}
4906+
}
4907+
}
4908+
}
4909+
}
4910+
}
4911+
4912+
template <typename T>
4913+
struct always_false : std::false_type {};
4914+
4915+
resizable_tensor params; // unused
4916+
resizable_tensor binary_mask, output_mask;
4917+
long diag;
4918+
float diag_value;
4919+
};
4920+
4921+
template <typename SUBNET>
4922+
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
4923+
4924+
template <typename SUBNET>
4925+
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
4926+
4927+
template <long diag, long num, long den, typename SUBNET>
4928+
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
4929+
4930+
// ----------------------------------------------------------------------------------------
4931+
4932+
}
48104933

4934+
#endif // DLIB_DNn_LAYERS_H_

dlib/dnn/layers_abstract.h

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3884,7 +3884,162 @@ namespace dlib
38843884

38853885
// ----------------------------------------------------------------------------------------
38863886

3887-
}
3887+
struct neg_infinity_tag {};
3888+
struct zero_tag {};
3889+
3890+
template<typename T>
3891+
struct is_special_value : std::false_type {};
3892+
template<>
3893+
struct is_special_value<neg_infinity_tag> : std::true_type {};
3894+
template<>
3895+
struct is_special_value<zero_tag> : std::true_type {};
3896+
3897+
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
3898+
class tril_
3899+
{
3900+
/*!
3901+
TEMPLATE PARAMETERS
3902+
- diag_: A long integer specifying the diagonal offset.
3903+
- tag_: A type tag specifying special values or void for numeric values.
3904+
- num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void).
3905+
- den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void).
3906+
3907+
REQUIREMENTS
3908+
- diag_ must be an integer.
3909+
- tag_ must be either neg_infinity_tag, zero_tag, or void.
3910+
- If tag_ is void, num_ and den_ are used to compute the diagonal value.
3911+
- If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
3912+
3913+
WHAT THIS OBJECT REPRESENTS
3914+
This object implements a layer in a deep neural network that applies a lower triangular mask to
3915+
its input tensor. The mask is defined such that all elements above the specified diagonal are set
3916+
to a given value. The diagonal offset and the mask value are determined by the template parameters.
3917+
3918+
DIAGONAL VALUE DETERMINATION
3919+
- If tag_ is neg_infinity_tag: diagonal value is set to negative infinity.
3920+
- If tag_ is zero_tag: diagonal value is set to zero.
3921+
- If tag_ is void: diagonal value is set to num_ / den_ as a float.
3922+
3923+
DIAGONAL OFFSET
3924+
The diag_ parameter determines the diagonal above which elements are masked:
3925+
- diag_ = 0: main diagonal
3926+
- diag_ > 0: diag_ steps above the main diagonal
3927+
- diag_ < 0: |diag_| steps below the main diagonal
3928+
3929+
EXAMPLE USAGE
3930+
// Create a layer that masks all elements above the main diagonal with -inf
3931+
tril_<0, neg_infinity_tag> layer1;
3932+
3933+
// Create a layer that masks all elements above the main diagonal with 0
3934+
tril_<0, zero_tag> layer2;
3935+
3936+
// Create a layer that masks all elements above the main diagonal with 0.5
3937+
tril_<0, void, 1, 2> layer3;
3938+
3939+
// Create a layer that masks all elements 5 positions above the main diagonal with -inf
3940+
tril_<5, neg_infinity_tag> layer4;
3941+
3942+
// Create a layer that masks all elements 3 positions below the main diagonal with 0.25
3943+
tril_<-3, void, 1, 4> layer5;
3944+
3945+
SERIALIZATION SUPPORT
3946+
This object supports serialization and deserialization via the serialize() and deserialize() functions.
3947+
!*/
3948+
3949+
public:
3950+
tril_() = default;
3951+
/*!
3952+
ensures
3953+
- This object is properly initialized.
3954+
!*/
3955+
3956+
template <typename SUBNET>
3957+
void setup(const SUBNET& sub);
3958+
/*!
3959+
requires
3960+
- SUBNET is a valid network layer type.
3961+
ensures
3962+
- Initializes the mask based on the dimensions of the input tensor from sub.
3963+
!*/
3964+
3965+
template <typename SUBNET>
3966+
void forward(const SUBNET& sub, resizable_tensor& output);
3967+
/*!
3968+
requires
3969+
- SUBNET is a valid network layer type.
3970+
ensures
3971+
- Applies the lower triangular mask to the input tensor from sub and stores the result in output.
3972+
!*/
3973+
3974+
template <typename SUBNET>
3975+
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
3976+
/*!
3977+
requires
3978+
- SUBNET is a valid network layer type.
3979+
ensures
3980+
- Computes the gradient of the loss with respect to the input tensor and stores it in sub.
3981+
!*/
3982+
3983+
inline dpoint map_input_to_output(const dpoint& p) const;
3984+
/*!
3985+
ensures
3986+
- Maps a point from the input tensor to the corresponding point in the output tensor.
3987+
!*/
3988+
3989+
inline dpoint map_output_to_input(const dpoint& p) const;
3990+
/*!
3991+
ensures
3992+
- Maps a point from the output tensor to the corresponding point in the input tensor.
3993+
!*/
3994+
3995+
const tensor& get_layer_params() const;
3996+
/*!
3997+
ensures
3998+
- Returns the parameters of this layer.
3999+
!*/
4000+
4001+
tensor& get_layer_params();
4002+
/*!
4003+
ensures
4004+
- Returns the parameters of this layer.
4005+
!*/
4006+
4007+
friend void serialize(const tril_& item, std::ostream& out);
4008+
/*!
4009+
ensures
4010+
- Serializes the state of this object to the given output stream.
4011+
!*/
4012+
4013+
friend void deserialize(tril_& item, std::istream& in);
4014+
/*!
4015+
ensures
4016+
- Deserializes the state of this object from the given input stream.
4017+
!*/
4018+
4019+
friend std::ostream& operator<<(std::ostream& out, const tril_& item);
4020+
/*!
4021+
ensures
4022+
- Prints a human-readable representation of this object to the given output stream.
4023+
!*/
4024+
4025+
friend void to_xml(const tril_& item, std::ostream& out);
4026+
/*!
4027+
ensures
4028+
- Serializes the state of this object to XML format and writes it to the given output stream.
4029+
!*/
4030+
};
4031+
4032+
template <typename SUBNET>
4033+
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
4034+
4035+
template <typename SUBNET>
4036+
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
38884037

3889-
#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
4038+
template <long diag, long num, long den, typename SUBNET>
4039+
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
4040+
4041+
// ----------------------------------------------------------------------------------------
4042+
4043+
}
38904044

4045+
#endif // DLIB_DNn_LAYERS_ABSTRACT_H_

dlib/dnn/visitors.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,22 @@ namespace dlib
10371037
update(i);
10381038
}
10391039

1040+
template <long diag, typename tag, long num, long den, typename U, typename E>
1041+
void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
1042+
{
1043+
start_node(i, "tril");
1044+
out << " | {diag|{" << diag << "}}";
1045+
out << " | {diag_value|{";
1046+
1047+
if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
1048+
else if (std::is_same<tag, zero_tag>::value) out << "0";
1049+
else out << static_cast<float>(num) / static_cast<float>(den);
1050+
1051+
out << "}}";
1052+
end_node();
1053+
update(i);
1054+
}
1055+
10401056
template <typename T, typename U, typename E>
10411057
void operator()(size_t i, const add_layer<T, U, E>&)
10421058
{

dlib/test/dnn.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,6 +2133,12 @@ void test_positional_encodings()
21332133
auto res = test_layer(l);
21342134
DLIB_TEST_MSG(res, res);
21352135
}
2136+
{
2137+
print_spinner();
2138+
tril_<-5, void, 1, 2> l;
2139+
auto res = test_layer(l);
2140+
DLIB_TEST_MSG(res, res);
2141+
}
21362142
{
21372143
print_spinner();
21382144
extract_<0,2,2,2> l;
@@ -4569,6 +4575,47 @@ void test_positional_encodings()
45694575
}
45704576
}
45714577

4578+
// ----------------------------------------------------------------------------------------
4579+
4580+
void test_tril()
4581+
{
4582+
print_spinner();
4583+
using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
4584+
net_type net;
4585+
4586+
// Input tensor
4587+
dlib::rand rnd;
4588+
const int nr = 2, nc = 3;
4589+
constexpr int n_samples = 3, k = 1;
4590+
std::vector<matrix<float>> x(n_samples);
4591+
matrix<float> xtmp(nr, nc);
4592+
for (int ii = 0; ii < n_samples; ++ii) {
4593+
for (int jj = 0; jj < nr; ++jj)
4594+
for (int kk = 0; kk < nc; ++kk)
4595+
xtmp(jj, kk) = rnd.get_random_gaussian();
4596+
x[ii] = xtmp;
4597+
}
4598+
4599+
// Convert input matrix to tensor
4600+
resizable_tensor input_tensor;
4601+
net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
4602+
net.forward(input_tensor);
4603+
4604+
// Expected output tensor (manually set for comparison)
4605+
resizable_tensor expected_output;
4606+
expected_output.copy_size(input_tensor);
4607+
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
4608+
for (int ii = 0; ii < n_samples; ++ii) {
4609+
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
4610+
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
4611+
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
4612+
}
4613+
4614+
// Compare output tensor with expected output
4615+
auto& net_output = layer<tag1>(net).get_output();
4616+
DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
4617+
}
4618+
45724619
// ----------------------------------------------------------------------------------------
45734620

45744621
class dnn_tester : public tester
@@ -4651,6 +4698,7 @@ void test_positional_encodings()
46514698
test_rms_normalize();
46524699
test_transpose();
46534700
test_positional_encodings();
4701+
test_tril();
46544702
test_basic_tensor_ops();
46554703
test_layers();
46564704
test_visit_functions();

0 commit comments

Comments
 (0)