@@ -51,7 +51,7 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None,
5151 self .win_length = win_length if win_length is not None else n_fft
5252 self .hop_length = hop_length if hop_length is not None else self .win_length // 2
5353 window = window_fn (self .win_length ) if wkwargs is None else window_fn (self .win_length , ** wkwargs )
54- self .window = window
54+ self .register_buffer ( ' window' , window )
5555 self .pad = pad
5656 self .power = power
5757 self .normalized = normalized
@@ -136,7 +136,7 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N
136136
137137 fb = torch .empty (0 ) if n_stft is None else F .create_fb_matrix (
138138 n_stft , self .f_min , self .f_max , self .n_mels , self .sample_rate )
139- self .fb = fb
139+ self .register_buffer ( 'fb' , fb )
140140
141141 def forward (self , specgram ):
142142 r"""
@@ -260,7 +260,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m
260260 if self .n_mfcc > self .MelSpectrogram .n_mels :
261261 raise ValueError ('Cannot select more MFCC coefficients than # mel bins' )
262262 dct_mat = F .create_dct (self .n_mfcc , self .MelSpectrogram .n_mels , self .norm )
263- self .dct_mat = dct_mat
263+ self .register_buffer ( ' dct_mat' , dct_mat )
264264 self .log_mels = log_mels
265265
266266 def forward (self , waveform ):
0 commit comments