@@ -100,194 +100,23 @@ void lfilter_core_generic_loop(
   }
 }
 
-class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
- public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs_normalized) {
-    auto device = waveform.device();
-    auto dtype = waveform.dtype();
-    int64_t n_batch = waveform.size(0);
-    int64_t n_channel = waveform.size(1);
-    int64_t n_sample = waveform.size(2);
-    int64_t n_order = a_coeffs_normalized.size(1);
-    int64_t n_sample_padded = n_sample + n_order - 1;
-
-    auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
-
-    auto options = torch::TensorOptions().dtype(dtype).device(device);
-    auto padded_output_waveform =
-        torch::zeros({n_batch, n_channel, n_sample_padded}, options);
-
-    if (device.is_cpu()) {
-      cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
-    } else if (device.is_cuda()) {
-#ifdef USE_CUDA
-      cuda_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
-#else
-      lfilter_core_generic_loop(
-          waveform, a_coeff_flipped, padded_output_waveform);
-#endif
-    } else {
-      lfilter_core_generic_loop(
-          waveform, a_coeff_flipped, padded_output_waveform);
-    }
-
-    auto output = padded_output_waveform.index(
-        {torch::indexing::Slice(),
-         torch::indexing::Slice(),
-         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
-
-    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
-    return output;
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto x = saved[0];
-    auto a_coeffs_normalized = saved[1];
-    auto y = saved[2];
-
-    int64_t n_channel = x.size(1);
-    int64_t n_order = a_coeffs_normalized.size(1);
-
-    auto dx = torch::Tensor();
-    auto da = torch::Tensor();
-    auto dy = grad_outputs[0];
-
-    namespace F = torch::nn::functional;
-
-    auto tmp =
-        DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
-            .flip(2);
-
-    if (x.requires_grad()) {
-      dx = tmp;
-    }
-
-    if (a_coeffs_normalized.requires_grad()) {
-      da = -torch::matmul(
-                tmp.transpose(0, 1).reshape({n_channel, 1, -1}),
-                F::pad(y, F::PadFuncOptions({n_order - 1, 0}))
-                    .unfold(2, n_order, 1)
-                    .transpose(0, 1)
-                    .reshape({n_channel, -1, n_order}))
-                .squeeze(1)
-                .flip(1);
-    }
-    return {dx, da};
-  }
-};
-
-class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
- public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::Tensor& waveform,
-      const torch::Tensor& b_coeffs) {
-    int64_t n_order = b_coeffs.size(1);
-    int64_t n_channel = b_coeffs.size(0);
-
-    namespace F = torch::nn::functional;
-    auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
-    auto padded_waveform =
-        F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-
-    auto output = F::conv1d(
-        padded_waveform,
-        b_coeff_flipped.unsqueeze(1),
-        F::Conv1dFuncOptions().groups(n_channel));
-
-    ctx->save_for_backward({waveform, b_coeffs, output});
-    return output;
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto x = saved[0];
-    auto b_coeffs = saved[1];
-    auto y = saved[2];
-
-    int64_t n_batch = x.size(0);
-    int64_t n_channel = x.size(1);
-    int64_t n_order = b_coeffs.size(1);
-
-    auto dx = torch::Tensor();
-    auto db = torch::Tensor();
-    auto dy = grad_outputs[0];
-
-    namespace F = torch::nn::functional;
-
-    if (b_coeffs.requires_grad()) {
-      db = F::conv1d(
-               F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
-                   .view({1, n_batch * n_channel, -1}),
-               dy.view({n_batch * n_channel, 1, -1}),
-               F::Conv1dFuncOptions().groups(n_batch * n_channel))
-               .view({n_batch, n_channel, -1})
-               .sum(0)
-               .flip(1);
-    }
-
-    if (x.requires_grad()) {
-      dx = F::conv1d(
-          F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
-          b_coeffs.unsqueeze(1),
-          F::Conv1dFuncOptions().groups(n_channel));
-    }
-
-    return {dx, db};
-  }
-};
-
-torch::Tensor lfilter_core(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  TORCH_CHECK(waveform.device() == a_coeffs.device());
-  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());
-
-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
-  TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
-  TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));
-
-  int64_t n_order = b_coeffs.size(1);
-
-  TORCH_INTERNAL_ASSERT(n_order > 0);
-
-  auto filtered_waveform = DifferentiableFIR::apply(
-      waveform,
-      b_coeffs /
-          a_coeffs.index(
-              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
-
-  auto output = DifferentiableIIR::apply(
-      filtered_waveform,
-      a_coeffs /
-          a_coeffs.index(
-              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
-  return output;
+} // namespace
+
+TORCH_LIBRARY(torchaudio, m) {
+  m.def(
+      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
 }
 
-} // namespace
-
-// Note: We want to avoid using "catch-all" kernel.
-// The following registration should be replaced with CPU specific registration.
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
+TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
 }
 
-TORCH_LIBRARY(torchaudio, m) {
-  m.def(
-      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
+#ifdef USE_CUDA
+TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
 }
+#endif
 
-TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
 }
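For context, the registration change in this hunk follows PyTorch's standard dispatcher pattern: TORCH_LIBRARY declares the operator schema once (the Tensor(a!) annotation marks the output buffer as mutated in place), and each TORCH_LIBRARY_IMPL block binds one kernel per dispatch key, with backend-specific kernels (CPU, CUDA) taking precedence over the CompositeExplicitAutograd fallback. Below is a minimal self-contained sketch of the same pattern; the myops::scale operator and its kernels are hypothetical stand-ins for torchaudio::_lfilter_core_loop, not code from this commit.

#include <torch/torch.h>
#include <torch/library.h>

// Kernel registered for the CPU dispatch key.
torch::Tensor scale_cpu(const torch::Tensor& x, double factor) {
  return x * factor;
}

// Generic fallback used when no backend-specific kernel is registered.
torch::Tensor scale_generic(const torch::Tensor& x, double factor) {
  return x * factor;
}

// Define the schema once; the "float" schema type maps to double in C++.
TORCH_LIBRARY(myops, m) {
  m.def("myops::scale(Tensor x, float factor) -> Tensor");
}

// Bind kernels per dispatch key. A direct backend registration (CPU)
// takes precedence over the CompositeExplicitAutograd alias key.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myops::scale", &scale_cpu);
}

TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
  m.impl("myops::scale", &scale_generic);
}

Once registered, the operator is reachable through the dispatcher from both C++ and Python (e.g. torch.ops.myops.scale(x, 2.0)); with this layout, supporting a new backend means adding another TORCH_LIBRARY_IMPL block rather than touching the schema definition, which is the point of replacing the earlier catch-all registration.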