Skip to content

Commit 717c8bd

Browse files
ynonaolga and facebook-github-bot
authored and committed
[Torchaudio] allow masked_value to be tensor (#4015)
Summary: Add support for mask_value to be a tensor: for example, if we want to fill with the mean values of features. When mask_value is a tensor, masked_fill creates an additional D2D copy; rewriting with torch.where avoids it. Differential Revision: D79451615
1 parent 89cf282 commit 717c8bd

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/torchaudio/functional/functional.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
817817
def mask_along_axis_iid(
818818
specgrams: Tensor,
819819
mask_param: int,
820-
mask_value: float,
820+
mask_value: Union[float, Tensor],
821821
axis: int,
822822
p: float = 1.0,
823823
) -> Tensor:
@@ -874,7 +874,12 @@ def mask_along_axis_iid(
874874

875875
# Per batch example masking
876876
specgrams = specgrams.transpose(axis, -1)
877-
specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
877+
# this aims to avoid CPU-GPU sync from upstream
878+
specgrams = (
879+
torch.where((mask >= mask_start) & (mask < mask_end), mask_value.repeat(*specgrams.shape), specgrams)
880+
if isinstance(mask_value, Tensor)
881+
else specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
882+
)
878883
specgrams = specgrams.transpose(axis, -1)
879884

880885
return specgrams

src/torchaudio/transforms/_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,7 @@ def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0)
11851185
self.iid_masks = iid_masks
11861186
self.p = p
11871187

1188-
def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
1188+
def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0) -> Tensor:
11891189
r"""
11901190
Args:
11911191
specgram (Tensor): Tensor of dimension `(..., freq, time)`.

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,20 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
456456
assert mask_specgrams.size() == specgrams.size()
457457
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
458458

459+
@parameterized.expand(list(itertools.product([100], [torch.tensor(0.0), torch.tensor(30.0)], [2, 3], [0.2, 1.0])))
460+
def test_mask_along_axis_iid_mask_value(self, mask_param, mask_value, axis, p):
461+
specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
462+
torch.manual_seed(0)
463+
mask_value.to(self.device)
464+
# as this operation is random we need to fix the seed for results to match
465+
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis, p=p)
466+
torch.manual_seed(0)
467+
mask_specgrams_float = F.mask_along_axis_iid(specgrams, mask_param, mask_value.item(), axis, p=p)
468+
assert torch.allclose(
469+
mask_specgrams, mask_specgrams_float
470+
), f"""Masking with float and tensor should be the same diff = {
471+
torch.abs(mask_specgrams - mask_specgrams_float).max()}"""
472+
459473
@parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
460474
def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
461475
"""mask_along_axis should not alter original input Tensor

0 commit comments

Comments (0)