1313# limitations under the License.
1414"""The LightningModule - an nn.Module with many additional features."""
1515
16- import collections
16+ import collections . abc
1717import inspect
1818import logging
1919import numbers
@@ -1712,7 +1712,7 @@ def tbptt_split_batch(self, batch, split_size):
17121712 for i, x in enumerate(batch):
17131713 if isinstance(x, torch.Tensor):
17141714 split_x = x[:, t:t + split_size]
1715- elif isinstance(x, collections.Sequence):
1715+ elif isinstance(x, collections.abc. Sequence):
17161716 split_x = [None] * len(x)
17171717 for batch_idx in range(len(x)):
17181718 split_x[batch_idx] = x[batch_idx][t:t + split_size]
@@ -1726,7 +1726,7 @@ def tbptt_split_batch(self, batch, split_size):
17261726 if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
17271727 Each returned batch split is passed separately to :meth:`training_step`.
17281728 """
1729- time_dims = [len (x [0 ]) for x in batch if isinstance (x , (Tensor , collections .Sequence ))]
1729+ time_dims = [len (x [0 ]) for x in batch if isinstance (x , (Tensor , collections .abc . Sequence ))]
17301730 assert len (time_dims ) >= 1 , "Unable to determine batch time dimension"
17311731 assert all (x == time_dims [0 ] for x in time_dims ), "Batch time dimension length is ambiguous"
17321732
@@ -1736,7 +1736,7 @@ def tbptt_split_batch(self, batch, split_size):
17361736 for i , x in enumerate (batch ):
17371737 if isinstance (x , Tensor ):
17381738 split_x = x [:, t : t + split_size ]
1739- elif isinstance (x , collections .Sequence ):
1739+ elif isinstance (x , collections .abc . Sequence ):
17401740 split_x = [None ] * len (x )
17411741 for batch_idx in range (len (x )):
17421742 split_x [batch_idx ] = x [batch_idx ][t : t + split_size ]
0 commit comments