@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output

     @staticmethod
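The FIR stage above is a grouped cross-correlation: the taps are flipped, the input is left-padded by n_order - 1 samples, and conv1d then computes y[n] = sum_k b[k] * x[n - k] per channel. A minimal sketch checking that equivalence against a direct delay-and-sum loop (shapes assumed from the surrounding code: waveform (batch, channel, time), coefficients (channel, n_order)):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 16)   # (batch, channel, time)
b = torch.randn(3, 4)       # (channel, n_order)
n_order = b.size(1)

# Flipped taps + left padding turn conv1d's cross-correlation into convolution.
conv_out = F.conv1d(F.pad(x, (n_order - 1, 0)), b.flip(1).unsqueeze(1), groups=3)

# Direct form: y[n] = sum_k b[k] * x[n - k], with x[n] = 0 for n < 0.
direct = torch.zeros_like(x)
for k in range(n_order):
    delayed = F.pad(x, (k, 0))[..., : x.size(-1)]  # x delayed by k samples
    direct += b[:, k].view(1, 3, 1) * delayed

assert torch.allclose(conv_out, direct, atol=1e-5)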
@@ -955,21 +956,28 @@ def backward(ctx, dy):
         n_batch = x.size(0)
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
-        db = (
-            F.conv1d(
-                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-                dy.view(n_batch * n_channel, 1, -1),
-                groups=n_batch * n_channel,
-            )
-            .view(n_batch, n_channel, -1)
-            .sum(0)
-            .flip(1)
-            if b_coeffs.requires_grad
-            else None
-        )
-        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
         return (dx, db)

+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+

 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
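TorchScript cannot compile torch.autograd.Function.apply, so ts_apply dispatches on torch.jit.is_scripting(): under scripting, forward is called directly with a dummy empty tensor standing in for ctx (hence the guards around save_for_backward, since the dead branch is skipped at compile time and no gradients are recorded on that path), while eager mode keeps the full autograd path. A toy sketch of the same dispatch with a hypothetical Scale function, not torchaudio code:

import torch

class Scale(torch.autograd.Function):
    # Hypothetical toy Function illustrating the ts_apply dispatch idea.
    @staticmethod
    def forward(ctx, x, s):
        if not torch.jit.is_scripting():
            ctx.save_for_backward(s)  # skipped under scripting: ctx is a dummy tensor there
        return x * s

    @staticmethod
    def backward(ctx, dy):
        (s,) = ctx.saved_tensors
        return dy * s, None

    @staticmethod
    def ts_apply(x, s):
        if torch.jit.is_scripting():
            # Scripted path: a plain function call; no autograd graph is built.
            return Scale.forward(torch.empty(0), x, s)
        else:
            return Scale.apply(x, s)

def scale(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    return Scale.ts_apply(x, s)

scripted = torch.jit.script(scale)  # compiles; calling Scale.apply directly would not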
@@ -984,7 +992,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
         )
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
         output = padded_output_waveform[:, :, n_order - 1 :]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output

     @staticmethod
@@ -1006,10 +1015,17 @@ def backward(ctx, dy):
         )
         return (dx, da)

+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+

 def _lfilter(waveform, a_coeffs, b_coeffs):
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])


 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
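Since both stages receive coefficients divided by a_coeffs[:, 0:1], the FIR-then-IIR cascade implements the standard difference equation normalized to a[0] = 1. A quick sanity check one could run against scipy (assuming torchaudio and scipy are installed; the tolerance is a loose float32 guess):

import torch
from scipy.signal import lfilter as scipy_lfilter
from torchaudio.functional import lfilter

x = torch.randn(1, 256)
b = torch.tensor([0.4, 0.2, 0.9])
a = torch.tensor([2.0, 0.2, -0.5])  # a[0] != 1 exercises the normalization

y = lfilter(x, a, b, clamp=False)  # note the (waveform, a, b) argument order
y_ref = torch.from_numpy(scipy_lfilter(b.numpy(), a.numpy(), x.numpy())).to(y.dtype)
assert torch.allclose(y, y_ref, atol=1e-4)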