@@ -4631,6 +4631,131 @@ namespace dlib
46314631 >
46324632 using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
46334633
4634+ // ----------------------------------------------------------------------------------------
4635+
4636+ template <
4637+ long _offset_k,
4638+ long _offset_nr,
4639+ long _offset_nc,
4640+ long _k,
4641+ long _nr,
4642+ long _nc
4643+ >
4644+ class slice_
4645+ {
4646+ static_assert (_offset_k >= 0 , " The channel offset must be >= 0." );
4647+ static_assert (_offset_nr >= 0 , " The row offset must be >= 0." );
4648+ static_assert (_offset_nc >= 0 , " The column offset must be >= 0." );
4649+ static_assert (_k > 0 , " The number of channels must be > 0." );
4650+ static_assert (_nr > 0 , " The number of rows must be > 0." );
4651+ static_assert (_nc > 0 , " The number of columns must be > 0." );
4652+ public:
4653+ slice_ (
4654+ )
4655+ {
4656+ }
4657+
4658+ template <typename SUBNET>
4659+ void setup (const SUBNET& sub)
4660+ {
4661+ DLIB_CASSERT ((long )sub.get_output ().size () >= sub.get_output ().num_samples ()*(_offset_k+_offset_nr+_offset_nc+_k*_nr*_nc),
4662+ " The tensor we are trying to slice from the input tensor is too big to fit into the input tensor." );
4663+ }
4664+
4665+ template <typename SUBNET>
4666+ void forward (const SUBNET& sub, resizable_tensor& output)
4667+ {
4668+ output.set_size (sub.get_output ().num_samples (), _k, _nr, _nc);
4669+ tt::copy_tensor (false , output, 0 , 0 , 0 , sub.get_output (), _offset_k, _offset_nr, _offset_nc, _k, _nr, _nc);
4670+ }
4671+
4672+ template <typename SUBNET>
4673+ void backward (const tensor& gradient_input, SUBNET& sub, tensor& /* params_grad*/ )
4674+ {
4675+ tt::copy_tensor (true , sub.get_gradient_input (), _offset_k, _offset_nr, _offset_nc, gradient_input, 0 , 0 , 0 , _k, _nr, _nc);
4676+ }
4677+
4678+ const tensor& get_layer_params () const { return params; }
4679+ tensor& get_layer_params () { return params; }
4680+
4681+ friend void serialize (const slice_& /* item*/ , std::ostream& out)
4682+ {
4683+ serialize (" slice_" , out);
4684+ serialize (_offset_k, out);
4685+ serialize (_offset_nr, out);
4686+ serialize (_offset_nc, out);
4687+ serialize (_k, out);
4688+ serialize (_nr, out);
4689+ serialize (_nc, out);
4690+ }
4691+
4692+ friend void deserialize (slice_& /* item*/ , std::istream& in)
4693+ {
4694+ std::string version;
4695+ deserialize (version, in);
4696+ if (version != " slice_" )
4697+ throw serialization_error (" Unexpected version '" +version+" ' found while deserializing dlib::slice_." );
4698+
4699+ long offset_k;
4700+ long offset_nr;
4701+ long offset_nc;
4702+ long k;
4703+ long nr;
4704+ long nc;
4705+ deserialize (offset_k, in);
4706+ deserialize (offset_nr, in);
4707+ deserialize (offset_nc, in);
4708+ deserialize (k, in);
4709+ deserialize (nr, in);
4710+ deserialize (nc, in);
4711+
4712+ if (offset_k != _offset_k) throw serialization_error (" Wrong offset_k found while deserializing dlib::slice_" );
4713+ if (offset_nr != _offset_nr) throw serialization_error (" Wrong offset_nr found while deserializing dlib::slice_" );
4714+ if (offset_nc != _offset_nc) throw serialization_error (" Wrong offset_nc found while deserializing dlib::slice_" );
4715+ if (k != _k) throw serialization_error (" Wrong k found while deserializing dlib::slice_" );
4716+ if (nr != _nr) throw serialization_error (" Wrong nr found while deserializing dlib::slice_" );
4717+ if (nc != _nc) throw serialization_error (" Wrong nc found while deserializing dlib::slice_" );
4718+ }
4719+
4720+ friend std::ostream& operator <<(std::ostream& out, const slice_& /* item*/ )
4721+ {
4722+ out << " slice\t ("
4723+ << " offset_k=" <<_offset_k
4724+ << " offset_nr=" <<_offset_nr
4725+ << " offset_nc=" <<_offset_nc
4726+ << " , k=" <<_k
4727+ << " , nr=" <<_nr
4728+ << " , nc=" <<_nc
4729+ << " )" ;
4730+ return out;
4731+ }
4732+
4733+ friend void to_xml (const slice_& /* item*/ , std::ostream& out)
4734+ {
4735+ out << " <slice" ;
4736+ out << " offset_k='" <<_offset_k<<" '" ;
4737+ out << " offset_nr='" <<_offset_nr<<" '" ;
4738+ out << " offset_nr='" <<_offset_nc<<" '" ;
4739+ out << " k='" <<_k<<" '" ;
4740+ out << " nr='" <<_nr<<" '" ;
4741+ out << " nc='" <<_nc<<" '" ;
4742+ out << " />\n " ;
4743+ }
4744+ private:
4745+ resizable_tensor params; // unused
4746+ };
4747+
4748+ template <
4749+ long offset_k,
4750+ long offset_nr,
4751+ long offset_nc,
4752+ long k,
4753+ long nr,
4754+ long nc,
4755+ typename SUBNET
4756+ >
4757+ using slice = add_layer<slice_<offset_k,offset_nr,offset_nc,k,nr,nc>, SUBNET>;
4758+
46344759// ----------------------------------------------------------------------------------------
46354760
46364761 template <long long row_stride = 2 , long long col_stride = 2 >
0 commit comments