@@ -975,6 +975,186 @@ namespace dlib
975975 >
976976 using resize_to = add_layer<resize_to_<NR,NC>, SUBNET>;
977977
978+ // ----------------------------------------------------------------------------------------
979+
980+ template <long k_ = -1 , long nr_ = -1 , long nc_ = -1 >
981+ class reshape_to_
982+ {
983+ public:
984+ explicit reshape_to_ () :
985+ output_k(k_),
986+ output_nr(nr_),
987+ output_nc(nc_)
988+ {
989+ static_assert (k_ == -1 || k_ > 0 , " Output k must be positive or -1" );
990+ static_assert (nr_ == -1 || nr_ > 0 , " Output nr must be positive or -1" );
991+ static_assert (nc_ == -1 || nc_ > 0 , " Output nc must be positive or -1" );
992+
993+ input_k = input_nr = input_nc = 0 ;
994+ needs_rescale = false ;
995+ }
996+
997+ // Getters for dimensions
998+ long get_output_k () const { return output_k; }
999+ long get_output_nr () const { return output_nr; }
1000+ long get_output_nc () const { return output_nc; }
1001+
1002+ // Setters for dimensions
1003+ void set_output_k (long k) {
1004+ DLIB_CASSERT (k == -1 || k > 0 , " Output k must be positive or -1 to keep original dimension" );
1005+ output_k = k;
1006+ }
1007+ void set_output_nr (long nr) {
1008+ DLIB_CASSERT (nr == -1 || nr > 0 , " output nr must be positive or -1 to keep original dimension" );
1009+ output_nr = nr;
1010+ }
1011+ void set_output_nc (long nc) {
1012+ DLIB_CASSERT (nc == -1 || nc > 0 , " output nc must be positive or -1 to keep original dimension" );
1013+ output_nc = nc;
1014+ }
1015+
1016+ template <typename SUBNET>
1017+ void setup (const SUBNET& sub)
1018+ {
1019+ const auto & input = sub.get_output ();
1020+ input_k = input.k ();
1021+ input_nr = input.nr ();
1022+ input_nc = input.nc ();
1023+
1024+ // Calculate output dimensions using input dims where target is -1
1025+ if (k_ == -1 ) output_k = input_k;
1026+ if (nr_ == -1 ) output_nr = input_nr;
1027+ if (nc_ == -1 ) output_nc = input_nc;
1028+
1029+ // Check if this is well a pure reshape
1030+ long input_elements = input_k * input_nr * input_nc;
1031+ long output_elements = output_k * output_nr * output_nc;
1032+ if (input_elements != output_elements && input_k == output_k) needs_rescale = true ;
1033+ DLIB_CASSERT (input_elements == output_elements || needs_rescale,
1034+ " Cannot reshape tensor of " << input_elements <<
1035+ " elements into shape with " << output_elements << " elements. " <<
1036+ " For spatial rescaling, the channel dimension (k) must remain constant." );
1037+ }
1038+
1039+ template <typename SUBNET>
1040+ void forward (const SUBNET& sub, resizable_tensor& output)
1041+ {
1042+ // Set the output size (always preserving batch dimension)
1043+ const tensor& input = sub.get_output ();
1044+ output.set_size (input.num_samples (), output_k, output_nr, output_nc);
1045+
1046+ if (!needs_rescale)
1047+ {
1048+ // Create an alias of the input tensor with the output shape
1049+ alias_tensor input_alias (output.num_samples (), output_k, output_nr, output_nc);
1050+ // Get a view of the input tensor with the new shape
1051+ auto input_reshaped = input_alias (const_cast <tensor&>(input), 0 );
1052+ // Copy the view to the output tensor
1053+ tt::copy_tensor (false , output, 0 , input_reshaped, 0 , input_reshaped.k ());
1054+ }
1055+ else
1056+ {
1057+ // Only spatial dimensions need to be resized
1058+ tt::resize_bilinear (output, input);
1059+ }
1060+ }
1061+
1062+ template <typename SUBNET>
1063+ void backward (const tensor& gradient_input, SUBNET& sub, tensor& /* params_grad*/ )
1064+ {
1065+ auto & grad = sub.get_gradient_input ();
1066+
1067+ if (!needs_rescale) {
1068+ // Create an alias of the gradient tensor with the original input shape
1069+ alias_tensor grad_alias (grad.num_samples (), grad.k (), grad.nr (), grad.nc ());
1070+ // Get a view of the input gradient with the required shape
1071+ auto grad_reshaped = grad_alias (const_cast <tensor&>(gradient_input), 0 );
1072+ // Copy the view to the output gradient
1073+ tt::copy_tensor (true , grad, 0 , grad_reshaped, 0 , grad_reshaped.k ());
1074+ }
1075+ else
1076+ {
1077+ // Only spatial dimensions were resized
1078+ tt::resize_bilinear_gradient (grad, gradient_input);
1079+ }
1080+ }
1081+
1082+ // Mapping functions for coordinate transformations
1083+ inline dpoint map_input_to_output (const dpoint& p) const {
1084+ double scale_x = output_nc / static_cast <double >(input_nc);
1085+ double scale_y = output_nr / static_cast <double >(input_nr);
1086+ return dpoint (p.x () * scale_x, p.y () * scale_y);
1087+ }
1088+ inline dpoint map_output_to_input (const dpoint& p) const {
1089+ double scale_x = input_nc / static_cast <double >(output_nc);
1090+ double scale_y = input_nr / static_cast <double >(output_nr);
1091+ return dpoint (p.x () * scale_x, p.y () * scale_y);
1092+ }
1093+
1094+ const tensor& get_layer_params () const { return params; }
1095+ tensor& get_layer_params () { return params; }
1096+
1097+ friend void serialize (const reshape_to_& item, std::ostream& out)
1098+ {
1099+ serialize (" reshape_to_" , out);
1100+ serialize (item.input_k , out);
1101+ serialize (item.input_nr , out);
1102+ serialize (item.input_nc , out);
1103+ serialize (item.output_k , out);
1104+ serialize (item.output_nr , out);
1105+ serialize (item.output_nc , out);
1106+ serialize (item.needs_rescale , out);
1107+ }
1108+
1109+ friend void deserialize (reshape_to_& item, std::istream& in)
1110+ {
1111+ std::string version;
1112+ deserialize (version, in);
1113+ if (version != " reshape_to_" )
1114+ throw serialization_error (" Unexpected version '" + version + " ' found while deserializing dlib::reshape_to_." );
1115+ deserialize (item.input_k , in);
1116+ deserialize (item.input_nr , in);
1117+ deserialize (item.input_nc , in);
1118+ deserialize (item.output_k , in);
1119+ deserialize (item.output_nr , in);
1120+ deserialize (item.output_nc , in);
1121+ deserialize (item.needs_rescale , in);
1122+ }
1123+
1124+ friend std::ostream& operator <<(std::ostream& out, const reshape_to_& item)
1125+ {
1126+ out << " reshape_to (" ;
1127+ out << " k=" << std::to_string (item.output_k );
1128+ out << " , nr=" << std::to_string (item.output_nr );
1129+ out << " , nc=" << std::to_string (item.output_nc );
1130+ out << " , mode=" << (item.needs_rescale ? " spatial_rescale" : " pure_reshape" );
1131+ out << " )" ;
1132+ return out;
1133+ }
1134+
1135+ friend void to_xml (const reshape_to_& item, std::ostream& out)
1136+ {
1137+ out << " <reshape_to"
1138+ << " k='" << item.output_k << " '"
1139+ << " nr='" << item.output_nr << " '"
1140+ << " nc='" << item.output_nc << " '"
1141+ << " mode='" << (item.needs_rescale ? " spatial_rescale" : " pure_reshape" ) << " '"
1142+ << " />\n " ;
1143+ }
1144+
1145+ private:
1146+ long input_k, input_nr, input_nc; // Input dimensions
1147+ long output_k, output_nr, output_nc; // Output dimensions
1148+ bool needs_rescale;
1149+ resizable_tensor params; // No trainable parameters
1150+ };
1151+
1152+ template <long k, long nr, long nc, typename SUBNET>
1153+ using reshape_to = add_layer<reshape_to_<k, nr, nc>, SUBNET>;
1154+
1155+ template <long k, long nr, long nc, typename SUBNET>
1156+ using flatten = add_layer<reshape_to_<k * nr * nc, 1 , 1 >, SUBNET>;
1157+
9781158// ----------------------------------------------------------------------------------------
9791159
9801160 template <
0 commit comments