In SubModuleXD, shapes are duplicated within output_shape and _to_xd #81

@shankstm

Description

For example, looking at SubModule2D:

    def output_shape(self, dim=None):
        if dim == 1:
            f, l = self._output_shape
            return (f * l,)
        ...
    ...
    def _to_1d(self, submodule_output):
        """
        :param submodule_output: torch.Tensor (Batch + 2D)
        :return: torch.Tensor (Batch + 1D)
        """
        n, f, l = submodule_output.size()
        return submodule_output.view(n, f * l)

The rule that the 2D -> 1D conversion maps (F, S) to (F * S) is therefore encoded in two places within the SubModule2D class: once in output_shape and once in _to_1d. The same holds in general for every mD -> nD conversion. It may be worth eliminating this duplication so that the two cannot drift apart.
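One possible way to remove the duplication, sketched under assumptions: keep the shape rule in a single method and derive both output_shape and the tensor conversion from it. The base-class layout and the names _convert_shape and _convert are hypothetical (they are not part of the existing SubModuleXD code). The behaviour for dim=None is also assumed, because the issue elides those branches. Like the original _to_1d, _convert only relies on the tensor's .size() and .view() methods.

```python
class SubModuleXD:
    """Hypothetical base class: each subclass defines its shape rule exactly once."""

    def __init__(self, output_shape):
        # Per-sample output shape of the wrapped submodule (no batch dimension).
        self._output_shape = tuple(output_shape)

    def _convert_shape(self, shape, dim):
        """Single source of truth: map a per-sample shape to its dim-D form."""
        raise NotImplementedError

    def output_shape(self, dim=None):
        # Assumption: dim=None means "the submodule's native shape".
        if dim is None:
            return self._output_shape
        return self._convert_shape(self._output_shape, dim)

    def _convert(self, submodule_output, dim):
        """
        :param submodule_output: torch.Tensor (Batch + native shape)
        :return: torch.Tensor (Batch + dim-D shape)
        """
        n = submodule_output.size(0)
        # Ask the same rule for the target shape instead of re-deriving it here.
        target = self._convert_shape(tuple(submodule_output.size()[1:]), dim)
        return submodule_output.view(n, *target)


class SubModule2D(SubModuleXD):
    def _convert_shape(self, shape, dim):
        f, s = shape
        if dim == 1:
            return (f * s,)  # (F, S) -> (F * S,) is now stated only here
        if dim == 2:
            return (f, s)
        raise ValueError(f"unsupported target dim: {dim}")

    def _to_1d(self, submodule_output):
        return self._convert(submodule_output, 1)
```

With this layout, output_shape(dim=1) and the shape produced by _to_1d come from the same _convert_shape call, so a change to the flattening rule only has to be made once. An mD submodule would supply its own _convert_shape and inherit the rest.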

Metadata

Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests