@@ -3884,7 +3884,162 @@ namespace dlib
38843884
38853885// ----------------------------------------------------------------------------------------
38863886
3887- }
3887+ struct neg_infinity_tag {}; // Tag type: passed as tril_'s tag_ to select negative infinity as the mask value.
3888+ struct zero_tag {}; // Tag type: passed as tril_'s tag_ to select zero as the mask value.
3889+
3890+ template <typename T>
3891+ struct is_special_value : std::false_type {}; // Primary template: an arbitrary T is not a special-value tag.
3892+ template <>
3893+ struct is_special_value <neg_infinity_tag> : std::true_type {}; // neg_infinity_tag is a special-value tag.
3894+ template <>
3895+ struct is_special_value <zero_tag> : std::true_type {}; // zero_tag is a special-value tag.
3896+
3897+ template <long diag_, typename tag_, long num_ = 0 , long den_ = 1 >
3898+ class tril_
3899+ {
3900+ /*!
3901+ TEMPLATE PARAMETERS
3902+ - diag_: A long integer specifying the diagonal offset.
3903+ - tag_: A type tag selecting a special mask value (neg_infinity_tag or zero_tag), or void to use the numeric value num_ / den_.
3904+ - num_: Numerator of the numeric mask value (default is 0, only used if tag_ is void).
3905+ - den_: Denominator of the numeric mask value (default is 1, only used if tag_ is void).
3906+
3907+ REQUIREMENTS
3908+ - diag_ must be an integer.
3909+ - tag_ must be either neg_infinity_tag, zero_tag, or void.
3910+ - If tag_ is void, num_ and den_ are used to compute the mask value (num_ / den_).
3911+ - If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
3912+
3913+ WHAT THIS OBJECT REPRESENTS
3914+ This object implements a layer in a deep neural network that applies a lower triangular mask to
3915+ its input tensor. The mask is defined such that all elements above the specified diagonal are set
3916+ to a given value. The diagonal offset and the mask value are determined by the template parameters.
3917+
3918+ MASK VALUE DETERMINATION
3919+ - If tag_ is neg_infinity_tag: mask value is set to negative infinity.
3920+ - If tag_ is zero_tag: mask value is set to zero.
3921+ - If tag_ is void: mask value is set to num_ / den_ as a float.
3922+
3923+ DIAGONAL OFFSET
3924+ The diag_ parameter determines the diagonal above which elements are masked:
3925+ - diag_ = 0: main diagonal
3926+ - diag_ > 0: diag_ steps above the main diagonal
3927+ - diag_ < 0: |diag_| steps below the main diagonal
3928+
3929+ EXAMPLE USAGE
3930+ // Create a layer that masks all elements above the main diagonal with -inf
3931+ tril_<0, neg_infinity_tag> layer1;
3932+
3933+ // Create a layer that masks all elements above the main diagonal with 0
3934+ tril_<0, zero_tag> layer2;
3935+
3936+ // Create a layer that masks all elements above the main diagonal with 0.5
3937+ tril_<0, void, 1, 2> layer3;
3938+
3939+ // Create a layer that masks, with -inf, all elements above the diagonal lying 5 steps above the main diagonal
3940+ tril_<5, neg_infinity_tag> layer4;
3941+
3942+ // Create a layer that masks, with 0.25, all elements above the diagonal lying 3 steps below the main diagonal
3943+ tril_<-3, void, 1, 4> layer5;
3944+
3945+ SERIALIZATION SUPPORT
3946+ This object supports serialization and deserialization via the serialize() and deserialize() functions.
3947+ !*/
3948+
3949+ public:
3950+ tril_ () = default ;
3951+ /*!
3952+ ensures
3953+ - This object is properly initialized.
3954+ !*/
3955+
3956+ template <typename SUBNET>
3957+ void setup (const SUBNET& sub);
3958+ /*!
3959+ requires
3960+ - SUBNET is a valid network layer type.
3961+ ensures
3962+ - Initializes the mask based on the dimensions of the input tensor from sub.
3963+ !*/
3964+
3965+ template <typename SUBNET>
3966+ void forward (const SUBNET& sub, resizable_tensor& output);
3967+ /*!
3968+ requires
3969+ - SUBNET is a valid network layer type.
3970+ ensures
3971+ - Applies the lower triangular mask to the input tensor from sub and stores the result in output.
3972+ !*/
3973+
3974+ template <typename SUBNET>
3975+ void backward (const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
3976+ /*!
3977+ requires
3978+ - SUBNET is a valid network layer type.
3979+ ensures
3980+ - Computes the gradient of the loss with respect to the input tensor and stores it in sub.
3981+ !*/
3982+
3983+ inline dpoint map_input_to_output (const dpoint& p) const ;
3984+ /*!
3985+ ensures
3986+ - Maps a point from the input tensor to the corresponding point in the output tensor.
3987+ !*/
3988+
3989+ inline dpoint map_output_to_input (const dpoint& p) const ;
3990+ /*!
3991+ ensures
3992+ - Maps a point from the output tensor to the corresponding point in the input tensor.
3993+ !*/
3994+
3995+ const tensor& get_layer_params () const ;
3996+ /*!
3997+ ensures
3998+ - Returns a read-only reference to the parameters of this layer.
3999+ !*/
4000+
4001+ tensor& get_layer_params ();
4002+ /*!
4003+ ensures
4004+ - Returns a mutable reference to the parameters of this layer.
4005+ !*/
4006+
4007+ friend void serialize (const tril_& item, std::ostream& out);
4008+ /*!
4009+ ensures
4010+ - Serializes the state of this object to the given output stream.
4011+ !*/
4012+
4013+ friend void deserialize (tril_& item, std::istream& in);
4014+ /*!
4015+ ensures
4016+ - Deserializes the state of this object from the given input stream.
4017+ !*/
4018+
4019+ friend std::ostream& operator <<(std::ostream& out, const tril_& item);
4020+ /*!
4021+ ensures
4022+ - Prints a human-readable representation of this object to the given output stream.
4023+ !*/
4024+
4025+ friend void to_xml (const tril_& item, std::ostream& out);
4026+ /*!
4027+ ensures
4028+ - Serializes the state of this object to XML format and writes it to the given output stream.
4029+ !*/
4030+ };
4031+
4032+ template <typename SUBNET>
4033+ using tril = add_layer<tril_<0 , zero_tag>, SUBNET>; // tril_ with diag_ = 0 and zero_tag, wrapped in add_layer on top of SUBNET.
4034+
4035+ template <typename SUBNET>
4036+ using tril_mask = add_layer<tril_<0 , neg_infinity_tag>, SUBNET>; // tril_ with diag_ = 0 and neg_infinity_tag, wrapped in add_layer on top of SUBNET.
38884037
3889- #endif // DLIB_DNn_LAYERS_ABSTRACT_H_
4038+ template <long diag, long num, long den, typename SUBNET>
4039+ using tril_diag = add_layer<tril_<diag, void , num, den>, SUBNET>; // tril_ with tag_ = void: caller supplies the diagonal offset and the num / den pair.
4040+
4041+ // ----------------------------------------------------------------------------------------
4042+
4043+ }
38904044
4045+ #endif // DLIB_DNn_LAYERS_ABSTRACT_H_
0 commit comments