Skip to content

Commit 4e53f83

Browse files
authored
Add tril_ layer for lower triangular matrix operations (#3018)
* Add tril_ layer for lower triangular matrix operations
* Improved layer consistency
* Added constant_wrapper to fix the issue of the float in the template in C++17
* Looking for a solution for C++14
* Refactor tril_ layer for improved flexibility and C++14 compatibility
* Updates
* Updates
* Updates
* Updates
* Updates
* Updates
1 parent 72822fe commit 4e53f83

File tree

4 files changed

+346
-0
lines changed

4 files changed

+346
-0
lines changed

dlib/dnn/layers.h

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4696,6 +4696,132 @@ namespace dlib
46964696

46974697
template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;
46984698

4699+
// ----------------------------------------------------------------------------------------

// Tag types selecting the value tril_ writes above the diagonal.
struct neg_infinity_tag {};
struct zero_tag {};

// Trait that is true exactly for the two tag types above.  Equivalent to a
// std::false_type primary template with one std::true_type specialization per
// tag, just expressed as a single disjunction.
template<typename T>
struct is_special_value
    : std::integral_constant<bool,
                             std::is_same<T, neg_infinity_tag>::value ||
                             std::is_same<T, zero_tag>::value> {};
4710+
4711+
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
4712+
class tril_
4713+
{
4714+
public:
4715+
tril_(): diag(diag_), diag_value(compute_diag_value()) {}
4716+
4717+
template <typename SUBNET>
4718+
void setup(const SUBNET& /*sub*/)
4719+
{
4720+
}
4721+
4722+
template <typename SUBNET>
4723+
void forward(const SUBNET& sub, resizable_tensor& output)
4724+
{
4725+
auto& prev = sub.get_output();
4726+
output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
4727+
4728+
check_mask(prev);
4729+
tt::multiply(false, output, prev, binary_mask);
4730+
if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
4731+
}
4732+
template <typename SUBNET>
4733+
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
4734+
{
4735+
auto& prev_grad = sub.get_gradient_input();
4736+
tt::multiply(true, prev_grad, gradient_input, binary_mask);
4737+
}
4738+
4739+
inline dpoint map_input_to_output(const dpoint& p) const { return p; }
4740+
inline dpoint map_output_to_input(const dpoint& p) const { return p; }
4741+
4742+
const tensor& get_layer_params() const { return params; }
4743+
tensor& get_layer_params() { return params; }
4744+
4745+
friend void serialize(const tril_& item, std::ostream& out)
4746+
{
4747+
serialize("tril_", out);
4748+
serialize(item.diag, out);
4749+
serialize(item.diag_value, out);
4750+
}
4751+
friend void deserialize(tril_& item, std::istream& in)
4752+
{
4753+
std::string version;
4754+
deserialize(version, in);
4755+
if (version != "tril_")
4756+
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
4757+
deserialize(item.diag, in);
4758+
deserialize(item.diag_value, in);
4759+
}
4760+
4761+
friend std::ostream& operator<<(std::ostream& out, const tril_& item)
4762+
{
4763+
out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
4764+
return out;
4765+
}
4766+
friend void to_xml(const tril_& item, std::ostream& out)
4767+
{
4768+
out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
4769+
}
4770+
4771+
private:
4772+
float compute_diag_value() const {
4773+
if (std::is_same<tag_, neg_infinity_tag>::value)
4774+
return -std::numeric_limits<float>::infinity();
4775+
else if (std::is_same<tag_, zero_tag>::value)
4776+
return 0.0f;
4777+
else
4778+
return static_cast<float>(num_) / static_cast<float>(den_);
4779+
}
4780+
4781+
void check_mask(const tensor& t)
4782+
{
4783+
if (!have_same_dimensions(binary_mask, t)) {
4784+
binary_mask.copy_size(t);
4785+
binary_mask = 1;
4786+
if (diag_value != 0.0f) {
4787+
output_mask.copy_size(t);
4788+
output_mask = 0;
4789+
}
4790+
for (long s = 0; s < output_mask.num_samples(); ++s)
4791+
{
4792+
for (long k = 0; k < output_mask.k(); ++k)
4793+
{
4794+
for (long r = 0; r < output_mask.nr(); ++r)
4795+
{
4796+
for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c)
4797+
{
4798+
if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
4799+
binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
4800+
}
4801+
}
4802+
}
4803+
}
4804+
}
4805+
}
4806+
4807+
template <typename T>
4808+
struct always_false : std::false_type {};
4809+
4810+
resizable_tensor params; // unused
4811+
resizable_tensor binary_mask, output_mask;
4812+
long diag;
4813+
float diag_value;
4814+
};
4815+
4816+
template <typename SUBNET>
4817+
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
4818+
4819+
template <typename SUBNET>
4820+
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
4821+
4822+
template <long diag, long num, long den, typename SUBNET>
4823+
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
4824+
46994825
// ----------------------------------------------------------------------------------------
47004826

47014827
}

dlib/dnn/layers_abstract.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3711,6 +3711,162 @@ namespace dlib
37113711
template <typename SUBNET>
37123712
using transpose = add_layer<transpose_, SUBNET>;
37133713

3714+
// ----------------------------------------------------------------------------------------
3715+
3716+
struct neg_infinity_tag {};
3717+
struct zero_tag {};
3718+
3719+
template<typename T>
3720+
struct is_special_value : std::false_type {};
3721+
template<>
3722+
struct is_special_value<neg_infinity_tag> : std::true_type {};
3723+
template<>
3724+
struct is_special_value<zero_tag> : std::true_type {};
3725+
3726+
template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
3727+
class tril_
3728+
{
3729+
/*!
3730+
TEMPLATE PARAMETERS
3731+
- diag_: A long integer specifying the diagonal offset.
3732+
- tag_: A type tag specifying special values or void for numeric values.
3733+
- num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void).
3734+
- den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void).
3735+
3736+
REQUIREMENTS
3737+
- diag_ must be an integer.
3738+
- tag_ must be either neg_infinity_tag, zero_tag, or void.
3739+
- If tag_ is void, num_ and den_ are used to compute the diagonal value.
3740+
- If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
3741+
3742+
WHAT THIS OBJECT REPRESENTS
3743+
This object implements a layer in a deep neural network that applies a lower triangular mask to
3744+
its input tensor. The mask is defined such that all elements above the specified diagonal are set
3745+
to a given value. The diagonal offset and the mask value are determined by the template parameters.
3746+
3747+
DIAGONAL VALUE DETERMINATION
3748+
- If tag_ is neg_infinity_tag: diagonal value is set to negative infinity.
3749+
- If tag_ is zero_tag: diagonal value is set to zero.
3750+
- If tag_ is void: diagonal value is set to num_ / den_ as a float.
3751+
3752+
DIAGONAL OFFSET
3753+
The diag_ parameter determines the diagonal above which elements are masked:
3754+
- diag_ = 0: main diagonal
3755+
- diag_ > 0: diag_ steps above the main diagonal
3756+
- diag_ < 0: |diag_| steps below the main diagonal
3757+
3758+
EXAMPLE USAGE
3759+
// Create a layer that masks all elements above the main diagonal with -inf
3760+
tril_<0, neg_infinity_tag> layer1;
3761+
3762+
// Create a layer that masks all elements above the main diagonal with 0
3763+
tril_<0, zero_tag> layer2;
3764+
3765+
// Create a layer that masks all elements above the main diagonal with 0.5
3766+
tril_<0, void, 1, 2> layer3;
3767+
3768+
// Create a layer that masks all elements 5 positions above the main diagonal with -inf
3769+
tril_<5, neg_infinity_tag> layer4;
3770+
3771+
// Create a layer that masks all elements 3 positions below the main diagonal with 0.25
3772+
tril_<-3, void, 1, 4> layer5;
3773+
3774+
SERIALIZATION SUPPORT
3775+
This object supports serialization and deserialization via the serialize() and deserialize() functions.
3776+
!*/
3777+
3778+
public:
3779+
tril_() = default;
3780+
/*!
3781+
ensures
3782+
- This object is properly initialized.
3783+
!*/
3784+
3785+
template <typename SUBNET>
3786+
void setup(const SUBNET& sub);
3787+
/*!
3788+
requires
3789+
- SUBNET is a valid network layer type.
3790+
ensures
3791+
- Initializes the mask based on the dimensions of the input tensor from sub.
3792+
!*/
3793+
3794+
template <typename SUBNET>
3795+
void forward(const SUBNET& sub, resizable_tensor& output);
3796+
/*!
3797+
requires
3798+
- SUBNET is a valid network layer type.
3799+
ensures
3800+
- Applies the lower triangular mask to the input tensor from sub and stores the result in output.
3801+
!*/
3802+
3803+
template <typename SUBNET>
3804+
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
3805+
/*!
3806+
requires
3807+
- SUBNET is a valid network layer type.
3808+
ensures
3809+
- Computes the gradient of the loss with respect to the input tensor and stores it in sub.
3810+
!*/
3811+
3812+
inline dpoint map_input_to_output(const dpoint& p) const;
3813+
/*!
3814+
ensures
3815+
- Maps a point from the input tensor to the corresponding point in the output tensor.
3816+
!*/
3817+
3818+
inline dpoint map_output_to_input(const dpoint& p) const;
3819+
/*!
3820+
ensures
3821+
- Maps a point from the output tensor to the corresponding point in the input tensor.
3822+
!*/
3823+
3824+
const tensor& get_layer_params() const;
3825+
/*!
3826+
ensures
3827+
- Returns the parameters of this layer.
3828+
!*/
3829+
3830+
tensor& get_layer_params();
3831+
/*!
3832+
ensures
3833+
- Returns the parameters of this layer.
3834+
!*/
3835+
3836+
friend void serialize(const tril_& item, std::ostream& out);
3837+
/*!
3838+
ensures
3839+
- Serializes the state of this object to the given output stream.
3840+
!*/
3841+
3842+
friend void deserialize(tril_& item, std::istream& in);
3843+
/*!
3844+
ensures
3845+
- Deserializes the state of this object from the given input stream.
3846+
!*/
3847+
3848+
friend std::ostream& operator<<(std::ostream& out, const tril_& item);
3849+
/*!
3850+
ensures
3851+
- Prints a human-readable representation of this object to the given output stream.
3852+
!*/
3853+
3854+
friend void to_xml(const tril_& item, std::ostream& out);
3855+
/*!
3856+
ensures
3857+
- Serializes the state of this object to XML format and writes it to the given output stream.
3858+
!*/
3859+
};
3860+
3861+
template <typename SUBNET>
3862+
using tril = add_layer<tril_<0, zero_tag>, SUBNET>;
3863+
3864+
template <typename SUBNET>
3865+
using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;
3866+
3867+
template <long diag, long num, long den, typename SUBNET>
3868+
using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
3869+
37143870
// ----------------------------------------------------------------------------------------
37153871

37163872
}

dlib/dnn/visitors.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,22 @@ namespace dlib
10291029
update(i);
10301030
}
10311031

1032+
template <long diag, typename tag, long num, long den, typename U, typename E>
1033+
void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
1034+
{
1035+
start_node(i, "tril");
1036+
out << " | {diag|{" << diag << "}}";
1037+
out << " | {diag_value|{";
1038+
1039+
if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
1040+
else if (std::is_same<tag, zero_tag>::value) out << "0";
1041+
else out << static_cast<float>(num) / static_cast<float>(den);
1042+
1043+
out << "}}";
1044+
end_node();
1045+
update(i);
1046+
}
1047+
10321048
template <typename T, typename U, typename E>
10331049
void operator()(size_t i, const add_layer<T, U, E>&)
10341050
{

dlib/test/dnn.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,12 @@ namespace
20232023
auto res = test_layer(l);
20242024
DLIB_TEST_MSG(res, res);
20252025
}
2026+
{
    // Gradient-check a tril_ layer with a below-diagonal offset (-5) and a
    // fractional replacement value (1/2).
    print_spinner();
    tril_<-5, void, 1, 2> l;
    auto res = test_layer(l);
    DLIB_TEST_MSG(res, res);
}
20262032
{
20272033
print_spinner();
20282034
extract_<0,2,2,2> l;
@@ -4447,6 +4453,47 @@ namespace
44474453
}
44484454
}
44494455

4456+
// ----------------------------------------------------------------------------------------
4457+
4458+
void test_tril()
4459+
{
4460+
print_spinner();
4461+
using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
4462+
net_type net;
4463+
4464+
// Input tensor
4465+
dlib::rand rnd;
4466+
const int nr = 2, nc = 3;
4467+
constexpr int n_samples = 3, k = 1;
4468+
std::vector<matrix<float>> x(n_samples);
4469+
matrix<float> xtmp(nr, nc);
4470+
for (int ii = 0; ii < n_samples; ++ii) {
4471+
for (int jj = 0; jj < nr; ++jj)
4472+
for (int kk = 0; kk < nc; ++kk)
4473+
xtmp(jj, kk) = rnd.get_random_gaussian();
4474+
x[ii] = xtmp;
4475+
}
4476+
4477+
// Convert input matrix to tensor
4478+
resizable_tensor input_tensor;
4479+
net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
4480+
net.forward(input_tensor);
4481+
4482+
// Expected output tensor (manually set for comparison)
4483+
resizable_tensor expected_output;
4484+
expected_output.copy_size(input_tensor);
4485+
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
4486+
for (int ii = 0; ii < n_samples; ++ii) {
4487+
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
4488+
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
4489+
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
4490+
}
4491+
4492+
// Compare output tensor with expected output
4493+
auto& net_output = layer<tag1>(net).get_output();
4494+
DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
4495+
}
4496+
44504497
// ----------------------------------------------------------------------------------------
44514498

44524499
class dnn_tester : public tester
@@ -4527,6 +4574,7 @@ namespace
45274574
test_layer_normalize();
45284575
test_rms_normalize();
45294576
test_transpose();
4577+
test_tril();
45304578
test_basic_tensor_ops();
45314579
test_layers();
45324580
test_visit_functions();

0 commit comments

Comments
 (0)