Skip to content

Commit 83a6780

Browse files
ynonaolga authored and facebook-github-bot committed
[Pyspeech] avoid .item call in feature_processing
Summary: tensor.item() causes a CPU-GPU sync and degrades performance. Avoid using .item() and rewrite mask_along_axis_iid so that mask_value may also be passed as a tensor. https://fburl.com/mlhub/37l2ufrp {F1980844270} It increases SM U by 1%. {F1980890373} {F1980890377} Differential Revision: D79451615
1 parent 9b57c7b commit 83a6780

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/torchaudio/functional/functional.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
817817
def mask_along_axis_iid(
818818
specgrams: Tensor,
819819
mask_param: int,
820-
mask_value: float,
820+
mask_value: Union[float, Tensor],
821821
axis: int,
822822
p: float = 1.0,
823823
) -> Tensor:
@@ -874,7 +874,12 @@ def mask_along_axis_iid(
874874

875875
# Per batch example masking
876876
specgrams = specgrams.transpose(axis, -1)
877-
specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
877+
# this aims to avoid CPU-GPU sync from upstream
878+
specgrams = (
879+
torch.where((mask >= mask_start) & (mask < mask_end), mask_value.repeat(*specgrams.shape), specgrams)
880+
if isinstance(mask_value, Tensor)
881+
else specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
882+
)
878883
specgrams = specgrams.transpose(axis, -1)
879884

880885
return specgrams

src/torchaudio/transforms/_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,7 @@ def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0)
11851185
self.iid_masks = iid_masks
11861186
self.p = p
11871187

1188-
def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
1188+
def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0) -> Tensor:
11891189
r"""
11901190
Args:
11911191
specgram (Tensor): Tensor of dimension `(..., freq, time)`.

0 commit comments

Comments
 (0)