@@ -120,41 +120,20 @@ def _f02uv(self, f0):
120120 uv = uv * (f0 > self .voiced_threshold )
121121 return uv
122122
123- def _f02sine (self , f0_values , upp ):
124- """ f0_values : (batchsize, length, dim)
123+ def _f02sine (self , f0 , upp ):
124+ """ f0 : (batchsize, length, dim)
125125 where dim indicates fundamental tone and overtones
126126 """
127- rad_values = (f0_values / self .sampling_rate ).fmod (1. ) # %1意味着n_har的乘积无法后处理优化
128- rand_ini = torch .rand (1 , self .dim , device = f0_values .device )
129- rand_ini [:, 0 ] = 0
130- rad_values [:, 0 , :] += rand_ini
131- is_half = rad_values .dtype is not torch .float32
132- tmp_over_one = torch .cumsum (rad_values .double (), 1 ) # % 1 #####%1意味着后面的cumsum无法再优化
133- if is_half :
134- tmp_over_one = tmp_over_one .half ()
135- else :
136- tmp_over_one = tmp_over_one .float ()
137- tmp_over_one *= upp
138- tmp_over_one = F .interpolate (
139- tmp_over_one .transpose (2 , 1 ), scale_factor = upp ,
140- mode = 'linear' , align_corners = True
141- ).transpose (2 , 1 )
142- rad_values = F .interpolate (rad_values .transpose (2 , 1 ), scale_factor = upp , mode = 'nearest' ).transpose (2 , 1 )
143- tmp_over_one = tmp_over_one .fmod (1. )
144- diff = F .conv2d (
145- tmp_over_one .unsqueeze (1 ), torch .FloatTensor ([[[[- 1. ], [1. ]]]]).to (tmp_over_one .device ),
146- stride = (1 , 1 ), padding = 0 , dilation = (1 , 1 )
147- ).squeeze (1 ) # Equivalent to torch.diff, but able to export ONNX
148- cumsum_shift = (diff < 0 ).double ()
149- cumsum_shift = torch .cat ((
150- torch .zeros ((f0_values .size ()[0 ], 1 , self .dim ), dtype = torch .double ).to (f0_values .device ),
151- cumsum_shift
152- ), dim = 1 )
153- sines = torch .sin (torch .cumsum (rad_values .double () + cumsum_shift , dim = 1 ) * 2 * np .pi )
154- if is_half :
155- sines = sines .half ()
156- else :
157- sines = sines .float ()
127+ rad = f0 / self .sampling_rate * torch .arange (1 , upp + 1 , device = f0 .device )
128+ rad2 = torch .fmod (rad [..., - 1 :].float () + 0.5 , 1.0 ) - 0.5
129+ rad_acc = rad2 .cumsum (dim = 1 ).fmod (1.0 ).to (f0 )
130+ rad += F .pad (rad_acc , (0 , 0 , 1 , - 1 ))
131+ rad = rad .reshape (f0 .shape [0 ], - 1 , 1 )
132+ rad = torch .multiply (rad , torch .arange (1 , self .dim + 1 , device = f0 .device ).reshape (1 , 1 , - 1 ))
133+ rand_ini = torch .rand (1 , 1 , self .dim , device = f0 .device )
134+ rand_ini [..., 0 ] = 0
135+ rad += rand_ini
136+ sines = torch .sin (2 * np .pi * rad )
158137 return sines
159138
160139 @torch .no_grad ()
@@ -166,8 +145,7 @@ def forward(self, f0, upp):
166145 output uv: tensor(batchsize=1, length, 1)
167146 """
168147 f0 = f0 .unsqueeze (- 1 )
169- fn = torch .multiply (f0 , torch .arange (1 , self .dim + 1 , device = f0 .device ).reshape ((1 , 1 , - 1 )))
170- sine_waves = self ._f02sine (fn , upp ) * self .sine_amp
148+ sine_waves = self ._f02sine (f0 , upp ) * self .sine_amp
171149 uv = (f0 > self .voiced_threshold ).float ()
172150 uv = F .interpolate (uv .transpose (2 , 1 ), scale_factor = upp , mode = 'nearest' ).transpose (2 , 1 )
173151 noise_amp = uv * self .noise_std + (1 - uv ) * self .sine_amp / 3
0 commit comments