For example, torch.nn.LSTM returns its output in the form torch.Tensor, (torch.Tensor, torch.Tensor), i.e. output, (h_n, c_n). This makes ModuleCompose.debug crash in print_intermediate: because the first item of the output has a shape, it assumes the second item does too. A quick fix would be to change line 77
from
if hasattr(x[0], 'shape'):
to
if all(hasattr(item, 'shape') for item in x):
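To show why checking only x[0] is not enough, here is a minimal sketch that mimics the nesting of the LSTM return value. It uses a hypothetical stand-in class FakeTensor (not part of PyTorch) so it runs without torch installed; the shapes are arbitrary.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only carries a .shape."""
    def __init__(self, *shape):
        self.shape = shape

# Same nesting as the value nn.LSTM returns: output, (h_n, c_n)
x = (FakeTensor(5, 3, 20), (FakeTensor(2, 3, 20), FakeTensor(2, 3, 20)))

old_check = hasattr(x[0], 'shape')                     # inspects only the first item
new_check = all(hasattr(item, 'shape') for item in x)  # inspects every item

print(old_check, new_check)  # True False
```

The old check passes even though x[1] is a plain tuple with no shape, which is what leads to the crash. The new check correctly rejects this output.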
Alternatively, we could recurse through the output until we reach an element that is not a tuple.
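The recursive option could look roughly like the sketch below. describe_shapes is a hypothetical helper name, not an existing function in this codebase, and the choice to report non-shaped leaves by their type name is an assumption made for illustration.

```python
from typing import Any


def describe_shapes(output: Any) -> Any:
    """Hypothetical helper: recursively map each element to its shape.

    Anything with a .shape (e.g. a torch.Tensor) becomes a tuple of ints.
    Tuples and lists are walked element by element and returned as tuples,
    so nested outputs such as LSTM's (output, (h_n, c_n)) are handled.
    Any other leaf is reported by its type name instead of crashing.
    """
    if hasattr(output, 'shape'):
        return tuple(output.shape)
    if isinstance(output, (tuple, list)):
        return tuple(describe_shapes(item) for item in output)
    return type(output).__name__
```

For an LSTM output this would yield a nested tuple of shapes mirroring the (output, (h_n, c_n)) structure, which print_intermediate could print directly instead of assuming every item has a shape.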