diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 810d1f51fc..f7f1bb864b 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -817,7 +817,7 @@ def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
 def mask_along_axis_iid(
     specgrams: Tensor,
     mask_param: int,
-    mask_value: float,
+    mask_value: Union[float, Tensor],
     axis: int,
     p: float = 1.0,
 ) -> Tensor:
@@ -874,7 +874,12 @@ def mask_along_axis_iid(
 
     # Per batch example masking
     specgrams = specgrams.transpose(axis, -1)
-    specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
+    # this aims to avoid CPU-GPU sync from upstream
+    specgrams = (
+        torch.where((mask >= mask_start) & (mask < mask_end), mask_value.repeat(*specgrams.shape), specgrams)
+        if isinstance(mask_value, Tensor)
+        else specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
+    )
     specgrams = specgrams.transpose(axis, -1)
 
     return specgrams
diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index 5bf914bc12..f208de13ae 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -1185,7 +1185,7 @@ def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0)
         self.iid_masks = iid_masks
         self.p = p
 
-    def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
+    def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0) -> Tensor:
         r"""
         Args:
             specgram (Tensor): Tensor of dimension `(..., freq, time)`.
diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index b08b63256c..e2e42f1fe7 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -456,6 +456,20 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
         assert mask_specgrams.size() == specgrams.size()
         assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
 
+    @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3], [0.2, 1.0])))
+    def test_mask_along_axis_iid_mask_value(self, mask_param, mask_value, axis, p):
+        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
+        mask_value_tensor = torch.tensor(mask_value, dtype=self.dtype, device=self.device)
+        torch.manual_seed(0)
+        # as this operation is random we need to fix the seed for results to match
+        mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value_tensor, axis, p=p)
+        torch.manual_seed(0)
+        mask_specgrams_float = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis, p=p)
+        assert torch.allclose(
+            mask_specgrams, mask_specgrams_float
+        ), f"""Masking with float and tensor should be the same diff = {
+            torch.abs(mask_specgrams - mask_specgrams_float).max()}"""
+
     @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
     def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
         """mask_along_axis should not alter original input Tensor
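
Below is a minimal usage sketch of what this change enables, assuming the patch above is applied. The shapes, the choice of `specgrams.mean()` as the fill value, and the keyword arguments are illustrative only; the point is that `mask_value` can now stay on the GPU as a 0-d tensor, so no `.item()` call (and no CPU-GPU synchronization) is needed before masking.

```python
# Illustrative sketch (not part of the patch): pass mask_value as an
# on-device tensor to F.mask_along_axis_iid to avoid a host-device sync.
import torch
import torchaudio.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
specgrams = torch.randn(4, 2, 1025, 400, device=device)

# Fill value computed on-device; keeping it as a tensor avoids .item()
# and the synchronization that a Python float would require.
mask_value = specgrams.mean()

masked = F.mask_along_axis_iid(specgrams, mask_param=100, mask_value=mask_value, axis=2, p=1.0)
```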