@@ -3,7 +3,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
@@ -231,7 +231,7 @@ def _concat_conditionings_for_batch(
         conditionings: List[torch.Tensor],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Concatenate provided conditioning tensors to one batched tensor.
-        If tensors have different sizes then pad them by zeros and creates
+        If tensors have different sizes then pad them by zeros and creates
         encoder_attention_mask to exclude padding from attention.
 
         Args:
@@ -242,9 +242,7 @@ def _concat_conditionings_for_batch(
         if any(c.shape[1] != max_len for c in conditionings):
             encoder_attention_masks = [None] * len(conditionings)
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
-                    conditionings[i], max_len
-                )
+                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
             encoder_attention_mask = torch.cat(encoder_attention_masks)
 
         return torch.cat(conditionings), encoder_attention_mask
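
For context, `_pad_conditioning` itself is not shown in this hunk. Below is a minimal sketch of what such a helper could look like, assuming it zero-pads a conditioning tensor along the token dimension and returns a matching attention mask; the function name, shapes, and dtype choices are illustrative assumptions, not the repository's actual implementation:

    import torch
    from typing import Tuple

    def pad_conditioning(cond: torch.Tensor, max_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Illustrative sketch only: zero-pad a (batch, seq_len, dim) conditioning tensor
        # to max_len tokens and build a mask that is 1 for real tokens, 0 for padding.
        batch, seq_len, dim = cond.shape
        mask = torch.ones(batch, max_len, device=cond.device, dtype=cond.dtype)
        if seq_len < max_len:
            pad = torch.zeros(batch, max_len - seq_len, dim, device=cond.device, dtype=cond.dtype)
            cond = torch.cat([cond, pad], dim=1)
            mask[:, seq_len:] = 0  # exclude padded positions from attention
        return cond, mask

Under that shape convention, the caller in the diff concatenates the padded tensors and their masks along the batch dimension (`torch.cat(conditionings)` and `torch.cat(encoder_attention_masks)`), so every prompt in the batch ends up with the same token length and an `encoder_attention_mask` marking which positions are real.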