@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
946
946
b_coeff_flipped = b_coeffs .flip (1 ).contiguous ()
947
947
padded_waveform = F .pad (waveform , (n_order - 1 , 0 ))
948
948
output = F .conv1d (padded_waveform , b_coeff_flipped .unsqueeze (1 ), groups = n_channel )
949
- ctx .save_for_backward (waveform , b_coeffs , output )
949
+ if not torch .jit .is_scripting ():
950
+ ctx .save_for_backward (waveform , b_coeffs , output )
950
951
return output
951
952
952
953
@staticmethod
@@ -956,32 +957,41 @@ def backward(ctx, dy):
956
957
n_channel = x .size (1 )
957
958
n_order = b_coeffs .size (1 )
958
959
db = F .conv1d (
959
- F .pad (x , (n_order - 1 , 0 )).view (1 , n_batch * n_channel , - 1 ),
960
- dy .view (n_batch * n_channel , 1 , - 1 ),
961
- groups = n_batch * n_channel
962
- ).view (
963
- n_batch , n_channel , - 1
964
- ).sum (0 ).flip (1 ) if b_coeffs .requires_grad else None
960
+ F .pad (x , (n_order - 1 , 0 )).view (1 , n_batch * n_channel , - 1 ),
961
+ dy .view (n_batch * n_channel , 1 , - 1 ),
962
+ groups = n_batch * n_channel
963
+ ).view (
964
+ n_batch , n_channel , - 1
965
+ ).sum (0 ).flip (1 ) if b_coeffs .requires_grad else None
965
966
dx = F .conv1d (
966
- F .pad (dy , (0 , n_order - 1 )),
967
- b_coeffs .unsqueeze (1 ),
968
- groups = n_channel
969
- ) if x .requires_grad else None
967
+ F .pad (dy , (0 , n_order - 1 )),
968
+ b_coeffs .unsqueeze (1 ),
969
+ groups = n_channel
970
+ ) if x .requires_grad else None
970
971
return (dx , db )
971
972
973
+ @staticmethod
974
+ def ts_apply (waveform , b_coeffs ):
975
+ if torch .jit .is_scripting ():
976
+ return DifferentiableFIR .forward (torch .empty (0 ), waveform , b_coeffs )
977
+ else :
978
+ return DifferentiableFIR .apply (waveform , b_coeffs )
979
+
980
+
972
981
class DifferentiableIIR (torch .autograd .Function ):
973
982
@staticmethod
974
983
def forward (ctx , waveform , a_coeffs_normalized ):
975
984
n_batch , n_channel , n_sample = waveform .shape
976
985
n_order = a_coeffs_normalized .size (1 )
977
986
n_sample_padded = n_sample + n_order - 1
978
987
979
- a_coeff_flipped = a_coeffs_normalized .flip (1 ).contiguous ();
988
+ a_coeff_flipped = a_coeffs_normalized .flip (1 ).contiguous ()
980
989
padded_output_waveform = torch .zeros (n_batch , n_channel , n_sample_padded ,
981
- device = waveform .device , dtype = waveform .dtype )
990
+ device = waveform .device , dtype = waveform .dtype )
982
991
_lfilter_core_loop (waveform , a_coeff_flipped , padded_output_waveform )
983
- output = padded_output_waveform [:,:,n_order - 1 :]
984
- ctx .save_for_backward (waveform , a_coeffs_normalized , output )
992
+ output = padded_output_waveform [:, :, n_order - 1 :]
993
+ if not torch .jit .is_scripting ():
994
+ ctx .save_for_backward (waveform , a_coeffs_normalized , output )
985
995
return output
986
996
987
997
@staticmethod
@@ -992,15 +1002,23 @@ def backward(ctx, dy):
992
1002
tmp = DifferentiableIIR .apply (dy .flip (2 ).contiguous (), a_coeffs_normalized ).flip (2 )
993
1003
dx = tmp if x .requires_grad else None
994
1004
da = - (tmp .transpose (0 , 1 ).reshape (n_channel , 1 , - 1 ) @
995
- F .pad (y , (n_order - 1 , 0 )).unfold (2 , n_order , 1 ).transpose (0 ,1 )
996
- .reshape (n_channel , - 1 , n_order )
997
- ).squeeze (1 ).flip (1 ) if a_coeffs_normalized .requires_grad else None
1005
+ F .pad (y , (n_order - 1 , 0 )).unfold (2 , n_order , 1 ).transpose (0 , 1 )
1006
+ .reshape (n_channel , - 1 , n_order )
1007
+ ).squeeze (1 ).flip (1 ) if a_coeffs_normalized .requires_grad else None
998
1008
return (dx , da )
999
1009
1010
+ @staticmethod
1011
+ def ts_apply (waveform , a_coeffs_normalized ):
1012
+ if torch .jit .is_scripting ():
1013
+ return DifferentiableIIR .forward (torch .empty (0 ), waveform , a_coeffs_normalized )
1014
+ else :
1015
+ return DifferentiableIIR .apply (waveform , a_coeffs_normalized )
1016
+
1017
+
1000
1018
def _lfilter (waveform , a_coeffs , b_coeffs ):
1001
- n_order = b_coeffs . size ( 1 )
1002
- filtered_waveform = DifferentiableFIR . apply ( waveform , b_coeffs / a_coeffs [:, 0 :1 ])
1003
- return DifferentiableIIR . apply ( filtered_waveform , a_coeffs / a_coeffs [:, 0 : 1 ])
1019
+ filtered_waveform = DifferentiableFIR . ts_apply ( waveform , b_coeffs / a_coeffs [:, 0 : 1 ] )
1020
+ return DifferentiableIIR . ts_apply ( filtered_waveform , a_coeffs / a_coeffs [:, 0 :1 ])
1021
+
1004
1022
1005
1023
def lfilter (waveform : Tensor , a_coeffs : Tensor , b_coeffs : Tensor , clamp : bool = True , batching : bool = True ) -> Tensor :
1006
1024
r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
@@ -1071,6 +1089,7 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
1071
1089
1072
1090
return output
1073
1091
1092
+
1074
1093
def lowpass_biquad (waveform : Tensor , sample_rate : int , cutoff_freq : float , Q : float = 0.707 ) -> Tensor :
1075
1094
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
1076
1095
0 commit comments