1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+
1415""" Augmentation on spectrogram: http://arxiv.org/abs/1904.08779 """
16+
1517import numpy as np
18+ import tensorflow as tf
1619
1720from nlpaug .flow import Sequential
1821from nlpaug .util import Action
1922from nlpaug .model .spectrogram import Spectrogram
2023from nlpaug .augmenter .spectrogram import SpectrogramAugmenter
2124
25+ from ..utils .utils import shape_list
26+
2227# ---------------------------- FREQ MASKING ----------------------------
2328
2429
@@ -75,6 +80,35 @@ def __init__(self,
7580 def substitute (self , data ):
7681 return self .flow .augment (data )
7782
83+
84+ class TFFreqMasking :
85+ def __init__ (self , num_masks : int = 1 , mask_factor : float = 27 ):
86+ self .num_masks = num_masks
87+ self .mask_factor = mask_factor
88+
89+ @tf .function
90+ def augment (self , spectrogram : tf .Tensor ):
91+ """
92+ Masking the frequency channels (shape[1])
93+ Args:
94+ spectrogram: shape (T, num_feature_bins, V)
95+ Returns:
96+ frequency masked spectrogram
97+ """
98+ T , F , V = shape_list (spectrogram , out_type = tf .int32 )
99+ for _ in range (self .num_masks ):
100+ f = tf .random .uniform ([], minval = 0 , maxval = self .mask_factor , dtype = tf .int32 )
101+ f = tf .minimum (f , F )
102+ f0 = tf .random .uniform ([], minval = 0 , maxval = (F - f ), dtype = tf .int32 )
103+ mask = tf .concat ([
104+ tf .ones ([T , f0 , V ], dtype = spectrogram .dtype ),
105+ tf .zeros ([T , f , V ], dtype = spectrogram .dtype ),
106+ tf .ones ([T , F - f0 - f , V ], dtype = spectrogram .dtype )
107+ ], axis = 1 )
108+ spectrogram = spectrogram * mask
109+ return spectrogram
110+
111+
78112# ---------------------------- TIME MASKING ----------------------------
79113
80114
@@ -101,9 +135,8 @@ def mask(self, data: np.ndarray) -> np.ndarray:
101135 """
102136 spectrogram = data .copy ()
103137 time = np .random .randint (0 , self .mask_factor + 1 )
104- time = min (time , spectrogram .shape [0 ])
105- time0 = np .random .randint (0 , spectrogram .shape [0 ] - time + 1 )
106138 time = min (time , int (self .p_upperbound * spectrogram .shape [0 ]))
139+ time0 = np .random .randint (0 , spectrogram .shape [0 ] - time + 1 )
107140 spectrogram [time0 :time0 + time , :, :] = 0
108141 return spectrogram
109142
@@ -139,3 +172,32 @@ def __init__(self,
139172
140173 def substitute (self , data ):
141174 return self .flow .augment (data )
175+
176+
177+ class TFTimeMasking :
178+ def __init__ (self , num_masks : int = 1 , mask_factor : float = 100 , p_upperbound : float = 1.0 ):
179+ self .num_masks = num_masks
180+ self .mask_factor = mask_factor
181+ self .p_upperbound = p_upperbound
182+
183+ @tf .function
184+ def augment (self , spectrogram : tf .Tensor ):
185+ """
186+ Masking the time channel (shape[0])
187+ Args:
188+ spectrogram: shape (T, num_feature_bins, V)
189+ Returns:
190+ frequency masked spectrogram
191+ """
192+ T , F , V = shape_list (spectrogram , out_type = tf .int32 )
193+ for _ in range (self .num_masks ):
194+ t = tf .random .uniform ([], minval = 0 , maxval = self .mask_factor , dtype = tf .int32 )
195+ t = tf .minimum (t , tf .cast (tf .cast (T , dtype = tf .float32 ) * self .p_upperbound , dtype = tf .int32 ))
196+ t0 = tf .random .uniform ([], minval = 0 , maxval = (T - t ), dtype = tf .int32 )
197+ mask = tf .concat ([
198+ tf .ones ([t0 , F , V ], dtype = spectrogram .dtype ),
199+ tf .zeros ([t , F , V ], dtype = spectrogram .dtype ),
200+ tf .ones ([T - t0 - t , F , V ], dtype = spectrogram .dtype )
201+ ], axis = 0 )
202+ spectrogram = spectrogram * mask
203+ return spectrogram
0 commit comments