For example, looking at SubModule2D:
def output_shape(self, dim=None):
    if dim == 1:
        f, l = self._output_shape
        return (f * l,)
    ...
...
def _to_1d(self, submodule_output):
    """
    :param submodule_output: torch.Tensor (Batch + 2D)
    :return: torch.Tensor (Batch + 1D)
    """
    n, f, l = submodule_output.size()
    return submodule_output.view(n, f * l)
The 2D -> 1D conversion rule, (F, L) -> (F * L), is encoded in two places within the SubModule2D class: once in output_shape and once in _to_1d. The same is true in general for mD -> nD conversions. It may be worth eliminating this duplication so that the reported shape and the actual reshape cannot drift apart.
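
One possible way to remove the duplication, as a rough sketch rather than a drop-in patch: move the shape rule into a single helper and have both methods call it. The helper name `_flatten_2d_to_1d` and the simplified constructor are hypothetical, and the fallback branch in `output_shape` is only a placeholder for the cases elided above. The sketch relies only on the tensor's `size()` and `view()` methods, so it needs no imports.

```python
class SubModule2D:
    """Hypothetical, stripped-down stand-in for the real SubModule2D."""

    def __init__(self, output_shape):
        # Shape of a single sample (no batch dimension), e.g. (F, L).
        self._output_shape = tuple(output_shape)

    @staticmethod
    def _flatten_2d_to_1d(shape):
        """Single source of truth for the (F, L) -> (F * L,) rule.

        Hypothetical helper, not part of the original class.
        """
        f, l = shape
        return (f * l,)

    def output_shape(self, dim=None):
        if dim == 1:
            return self._flatten_2d_to_1d(self._output_shape)
        # Placeholder: the original elides the remaining cases.
        return self._output_shape

    def _to_1d(self, submodule_output):
        """
        :param submodule_output: torch.Tensor (Batch + 2D)
        :return: torch.Tensor (Batch + 1D)
        """
        # Split off the batch dimension, then reuse the shared rule.
        n, *sample_shape = submodule_output.size()
        return submodule_output.view(n, *self._flatten_2d_to_1d(sample_shape))
```

With this layout, changing the flattening rule means touching one function only. The same pattern could presumably be repeated per (m, n) pair for the other mD -> nD conversions.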