@@ -123,19 +123,23 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
123123 r"""
124124 Args:
125125 waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
126- where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30
127126
128127 Returns:
129- torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames]
130- [1, 80, 3000] with default options
128+ torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks]
129+ n_chunks is the number of chunks of `n_samples` samples (rounded up) needed to cover the input waveform.
130+ [1, 80, 3000] with default options and 1 chunk
131131 """
132- # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
132+ n_chunks = ( waveform . shape [ 0 ] - 1 ) // self . n_samples + 1
133133 waveform = F .pad (
134134 waveform ,
135- (0 , self .n_samples - waveform .shape [0 ] - 1 ),
135+ (0 , self .n_samples * n_chunks - waveform .shape [0 ]),
136136 mode = "constant" ,
137137 value = self .padding_value ,
138138 )
139+ # Ideally we should do:
140+ # window = torch.hann_window(self.n_fft)
141+ # but this is not currently supported when lowering.
142+ # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
139143 window = 0.5 * (
140144 1
141145 - torch .cos (
@@ -145,10 +149,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
145149 / self .n_fft
146150 )
147151 )
148- # Ideally we should do instead
149- # window = torch.hann_window(self.n_fft)
150- # but this is not currently supported when lowering
151- # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
152152 stft = torch .stft (
153153 waveform ,
154154 n_fft = self .n_fft ,
@@ -157,7 +157,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
157157 center = True ,
158158 return_complex = True ,
159159 )
160- magnitudes = torch .abs (stft ) ** 2 # pyre-ignore[58]
160+ magnitudes = torch .abs (stft )[..., : - 1 ] ** 2 # pyre-ignore[58]
161161
162162 mel_spec = self .mel_filters @ magnitudes
163163
@@ -173,8 +173,7 @@ def export_processor():
173173 audio_tensor = torch .randn (480000 )
174174 chunk_tensor = audio_tensor [:93680 ]
175175 with torch .no_grad ():
176- # export. What is the min of waveforms?
177- dim = Dim ("waveform" , min = 1600 , max = audio_tensor .size (0 ))
176+ dim = Dim ("waveform" , min = 1600 , max = audio_tensor .size (0 ) * 10 ) # 10 chunks max
178177 ep : ExportedProgram = export (
179178 model , (chunk_tensor ,), dynamic_shapes = {"waveform" : {0 : dim }}, strict = True
180179 )
0 commit comments