Commit 66be36e

Cydral, davisking, and Copilot authored
Add reshape_to layer for flexible tensor reshaping/rescaling (#3076)
* Implementation of linear_ layer for neural networks. This layer provides an optimized linear transformation for multi-dimensional inputs.

* Minor change

* Update dlib/dnn/layers.h

Co-authored-by: Copilot <[email protected]>

* Add reshape_to and flatten layers to Dlib's DNN module

* Missing update to "visitors.h"

* Format fixing for reshape_to

* Update dlib/test/dnn.cpp

---------

Co-authored-by: Davis E. King <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 5caf80f commit 66be36e

File tree

4 files changed: +420 −1 lines changed

dlib/dnn/layers.h

Lines changed: 180 additions & 0 deletions
@@ -975,6 +975,186 @@ namespace dlib
        >
    using resize_to = add_layer<resize_to_<NR,NC>, SUBNET>;

    // ----------------------------------------------------------------------------------------

    template <long k_ = -1, long nr_ = -1, long nc_ = -1>
    class reshape_to_
    {
    public:
        explicit reshape_to_() :
            output_k(k_),
            output_nr(nr_),
            output_nc(nc_)
        {
            static_assert(k_ == -1 || k_ > 0, "Output k must be positive or -1");
            static_assert(nr_ == -1 || nr_ > 0, "Output nr must be positive or -1");
            static_assert(nc_ == -1 || nc_ > 0, "Output nc must be positive or -1");

            input_k = input_nr = input_nc = 0;
            needs_rescale = false;
        }

        // Getters for dimensions
        long get_output_k() const { return output_k; }
        long get_output_nr() const { return output_nr; }
        long get_output_nc() const { return output_nc; }

        // Setters for dimensions
        void set_output_k(long k) {
            DLIB_CASSERT(k == -1 || k > 0, "Output k must be positive or -1 to keep original dimension");
            output_k = k;
        }
        void set_output_nr(long nr) {
            DLIB_CASSERT(nr == -1 || nr > 0, "Output nr must be positive or -1 to keep original dimension");
            output_nr = nr;
        }
        void set_output_nc(long nc) {
            DLIB_CASSERT(nc == -1 || nc > 0, "Output nc must be positive or -1 to keep original dimension");
            output_nc = nc;
        }

        template <typename SUBNET>
        void setup(const SUBNET& sub)
        {
            const auto& input = sub.get_output();
            input_k = input.k();
            input_nr = input.nr();
            input_nc = input.nc();

            // Calculate output dimensions, using the input dims where the target is -1
            if (k_ == -1) output_k = input_k;
            if (nr_ == -1) output_nr = input_nr;
            if (nc_ == -1) output_nc = input_nc;

            // Check whether this is a pure reshape
            long input_elements = input_k * input_nr * input_nc;
            long output_elements = output_k * output_nr * output_nc;
            if (input_elements != output_elements && input_k == output_k) needs_rescale = true;
            DLIB_CASSERT(input_elements == output_elements || needs_rescale,
                "Cannot reshape tensor of " << input_elements <<
                " elements into shape with " << output_elements << " elements. " <<
                "For spatial rescaling, the channel dimension (k) must remain constant.");
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            // Set the output size (always preserving the batch dimension)
            const tensor& input = sub.get_output();
            output.set_size(input.num_samples(), output_k, output_nr, output_nc);

            if (!needs_rescale)
            {
                // Create an alias of the input tensor with the output shape
                alias_tensor input_alias(output.num_samples(), output_k, output_nr, output_nc);
                // Get a view of the input tensor with the new shape
                auto input_reshaped = input_alias(const_cast<tensor&>(input), 0);
                // Copy the view to the output tensor
                tt::copy_tensor(false, output, 0, input_reshaped, 0, input_reshaped.k());
            }
            else
            {
                // Only the spatial dimensions need to be resized
                tt::resize_bilinear(output, input);
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto& grad = sub.get_gradient_input();

            if (!needs_rescale) {
                // Create an alias of the gradient tensor with the original input shape
                alias_tensor grad_alias(grad.num_samples(), grad.k(), grad.nr(), grad.nc());
                // Get a view of the input gradient with the required shape
                auto grad_reshaped = grad_alias(const_cast<tensor&>(gradient_input), 0);
                // Accumulate the reshaped gradient into the subnetwork's gradient
                tt::copy_tensor(true, grad, 0, grad_reshaped, 0, grad_reshaped.k());
            }
            else
            {
                // Only the spatial dimensions were resized
                tt::resize_bilinear_gradient(grad, gradient_input);
            }
        }

        // Mapping functions for coordinate transformations
        inline dpoint map_input_to_output(const dpoint& p) const {
            double scale_x = output_nc / static_cast<double>(input_nc);
            double scale_y = output_nr / static_cast<double>(input_nr);
            return dpoint(p.x() * scale_x, p.y() * scale_y);
        }
        inline dpoint map_output_to_input(const dpoint& p) const {
            double scale_x = input_nc / static_cast<double>(output_nc);
            double scale_y = input_nr / static_cast<double>(output_nr);
            return dpoint(p.x() * scale_x, p.y() * scale_y);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const reshape_to_& item, std::ostream& out)
        {
            serialize("reshape_to_", out);
            serialize(item.input_k, out);
            serialize(item.input_nr, out);
            serialize(item.input_nc, out);
            serialize(item.output_k, out);
            serialize(item.output_nr, out);
            serialize(item.output_nc, out);
            serialize(item.needs_rescale, out);
        }

        friend void deserialize(reshape_to_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "reshape_to_")
                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::reshape_to_.");
            deserialize(item.input_k, in);
            deserialize(item.input_nr, in);
            deserialize(item.input_nc, in);
            deserialize(item.output_k, in);
            deserialize(item.output_nr, in);
            deserialize(item.output_nc, in);
            deserialize(item.needs_rescale, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const reshape_to_& item)
        {
            out << "reshape_to (";
            out << "k=" << std::to_string(item.output_k);
            out << ", nr=" << std::to_string(item.output_nr);
            out << ", nc=" << std::to_string(item.output_nc);
            out << ", mode=" << (item.needs_rescale ? "spatial_rescale" : "pure_reshape");
            out << ")";
            return out;
        }

        friend void to_xml(const reshape_to_& item, std::ostream& out)
        {
            out << "<reshape_to"
                << " k='" << item.output_k << "'"
                << " nr='" << item.output_nr << "'"
                << " nc='" << item.output_nc << "'"
                << " mode='" << (item.needs_rescale ? "spatial_rescale" : "pure_reshape") << "'"
                << "/>\n";
        }

    private:
        long input_k, input_nr, input_nc;    // Input dimensions
        long output_k, output_nr, output_nc; // Output dimensions
        bool needs_rescale;
        resizable_tensor params;             // No trainable parameters
    };

    template <long k, long nr, long nc, typename SUBNET>
    using reshape_to = add_layer<reshape_to_<k, nr, nc>, SUBNET>;

    template <long k, long nr, long nc, typename SUBNET>
    using flatten = add_layer<reshape_to_<k * nr * nc, 1, 1>, SUBNET>;

    // ----------------------------------------------------------------------------------------

    template <
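As a usage sketch (not part of this commit, and with illustrative layer sizes), the two aliases compose like any other dlib layers. Here, reshape_to rescales the convolution's feature maps to 16x16 (k stays at 16, so this is a bilinear spatial rescale whenever the incoming maps aren't already 16x16), and flatten then collapses the 16x16x16 maps into a 4096x1x1 tensor ahead of the fc layer, a pure reshape since 16*16*16 == 4096*1*1:

#include <dlib/dnn.h>

using namespace dlib;

// Illustrative network: the sizes below are assumptions for this sketch, not from the diff.
using net_type =
    loss_multiclass_log<
    fc<10,
    flatten<16, 16, 16,        // 16x16x16 -> 4096x1x1 (pure reshape)
    reshape_to<-1, 16, 16,     // keep k, rescale spatial dims to 16x16
    relu<con<16, 5, 5, 2, 2,
    input_rgb_image
    >>>>>>;

int main()
{
    net_type net;
    // Networks containing reshape_to serialize like any other dlib network.
    serialize("net.dat") << net;
    deserialize("net.dat") >> net;
}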

dlib/dnn/layers_abstract.h

Lines changed: 170 additions & 0 deletions
@@ -1642,6 +1642,176 @@ namespace dlib
        >
    using resize_to = add_layer<resize_to_<NR,NC>, SUBNET>;

    // ----------------------------------------------------------------------------------------

    template <long k_ = -1, long nr_ = -1, long nc_ = -1>
    class reshape_to_
    {
        /*!
            REQUIREMENTS ON TEMPLATE ARGUMENTS
                - k_, nr_, and nc_ must be either -1 or greater than 0.

            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                defined above.  It defines a layer that reshapes or resizes an input
                tensor into a different shape.  The layer operates in two modes:

                1. Pure Reshape Mode: When the total number of elements in the input
                   tensor equals the total number of elements in the output tensor,
                   this layer performs a simple reshaping operation without changing
                   the values.

                2. Spatial Rescaling Mode: When the channel dimension (k) remains
                   constant but the total number of elements changes, this layer
                   performs bilinear interpolation to resize the spatial dimensions
                   while preserving the channel information.

                The dimensions of the output tensor are determined by the template
                parameters:
                    - If k_ is -1, the output tensor will have the same number of
                      channels as the input.
                    - If nr_ is -1, the output tensor will have the same number of
                      rows as the input.
                    - If nc_ is -1, the output tensor will have the same number of
                      columns as the input.

                Setting a value of -1 for any dimension means "keep the original
                dimension from the input."

                Note that this layer will throw an exception if you attempt to change
                both the channel count (k) and the total number of elements.  Either:
                    - keep the total number of elements the same (Pure Reshape Mode), or
                    - keep the channel count the same and only change the spatial
                      dimensions (Spatial Rescaling Mode).
        !*/

    public:
        explicit reshape_to_();
        /*!
            ensures
                - #get_output_k() == k_
                - #get_output_nr() == nr_
                - #get_output_nc() == nc_
        !*/

        long get_output_k() const;
        /*!
            ensures
                - Returns the number of channels in the output tensor.  If this value
                  is -1, then the output will have the same number of channels as the
                  input.
        !*/

        long get_output_nr() const;
        /*!
            ensures
                - Returns the number of rows in the output tensor.  If this value is
                  -1, then the output will have the same number of rows as the input.
        !*/

        long get_output_nc() const;
        /*!
            ensures
                - Returns the number of columns in the output tensor.  If this value
                  is -1, then the output will have the same number of columns as the
                  input.
        !*/

        void set_output_k(long k);
        /*!
            requires
                - k == -1 || k > 0
            ensures
                - #get_output_k() == k
        !*/

        void set_output_nr(long nr);
        /*!
            requires
                - nr == -1 || nr > 0
            ensures
                - #get_output_nr() == nr
        !*/

        void set_output_nc(long nc);
        /*!
            requires
                - nc == -1 || nc > 0
            ensures
                - #get_output_nc() == nc
        !*/

        template <typename SUBNET> void setup(const SUBNET& sub);
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
            ensures
                - Configures this layer to operate on the output of sub.
                - Throws an exception if the total number of elements in the input
                  tensor doesn't match the total number of elements in the output
                  tensor while the channel dimension also differs.
        !*/

        template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
            ensures
                - Reshapes or resizes the output of sub and stores it in #output.
                - If the layer is in Pure Reshape Mode, then performs a pure reshape
                  operation.
                - If the layer is in Spatial Rescaling Mode, then performs bilinear
                  interpolation to resize the spatial dimensions while preserving the
                  channel information.
                - #output.num_samples() == sub.get_output().num_samples()
                - #output.k() == get_output_k() if get_output_k() != -1, otherwise sub.get_output().k()
                - #output.nr() == get_output_nr() if get_output_nr() != -1, otherwise sub.get_output().nr()
                - #output.nc() == get_output_nc() if get_output_nc() != -1, otherwise sub.get_output().nc()
        !*/

        template <typename SUBNET> void backward(
            const tensor& gradient_input,
            SUBNET& sub,
            tensor& params_grad
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
                - gradient_input has the same dimensions as the output of forward().
            ensures
                - Computes the gradients of this layer with respect to the input
                  tensor and parameters, and stores them in sub.get_gradient_input()
                  and params_grad, respectively.
                - This function supports both pure reshaping and spatial rescaling
                  operations.
        !*/

        dpoint map_input_to_output(dpoint p) const;
        /*!
            ensures
                - Maps a point in the input tensor's coordinate system to the
                  corresponding point in the output tensor.  This is useful for
                  tracking how spatial locations change through the network,
                  especially during spatial rescaling.
        !*/

        dpoint map_output_to_input(dpoint p) const;
        /*!
            ensures
                - Maps a point in the output tensor's coordinate system to the
                  corresponding point in the input tensor.  This is the inverse of
                  map_input_to_output().
        !*/

        const tensor& get_layer_params() const;
        /*!
            ensures
                - Returns the layer's parameters.  This layer has no parameters,
                  so this always returns an empty tensor.
        !*/

        tensor& get_layer_params();
        /*!
            ensures
                - Returns the layer's parameters.  This layer has no parameters,
                  so this always returns an empty tensor.
        !*/
    };

    template <long k, long nr, long nc, typename SUBNET>
    using reshape_to = add_layer<reshape_to_<k, nr, nc>, SUBNET>;

    template <long k, long nr, long nc, typename SUBNET>
    using flatten = add_layer<reshape_to_<k * nr * nc, 1, 1>, SUBNET>;
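To make the two modes concrete, here is a small worked sketch; the shapes are illustrative assumptions, not values from the commit. Suppose the input to the layer has k=32, nr=8, nc=8, i.e. 32*8*8 = 2048 elements per sample:

// reshape_to<128, 4, 4>: k changes, but the element count is preserved,
// so setup() selects Pure Reshape Mode and the values are merely reinterpreted.
static_assert(128 * 4 * 4 == 32 * 8 * 8, "pure reshape: element counts match");

// reshape_to<32, 16, 16>: 32*16*16 = 8192 != 2048, but k stays at 32, so
// setup() selects Spatial Rescaling Mode and bilinearly resizes 8x8 -> 16x16.
// The coordinate mapping then scales by the spatial ratio, e.g.
// map_input_to_output(dpoint(4, 2)) == dpoint(4 * 16/8.0, 2 * 16/8.0) == dpoint(8, 4).

// reshape_to<64, 16, 16>: the element count changes AND k changes, so neither
// mode applies and setup()'s DLIB_CASSERT fails.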
    // ----------------------------------------------------------------------------------------

    class dropout_
