Skip to content

Commit 0d5cd34

Browse files
committed
fix(rtrvc): skip head unimplemented
1 parent df83554 commit 0d5cd34

File tree

5 files changed

+32
-38
lines changed

5 files changed

+32
-38
lines changed

infer/lib/rtrvc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def infer(
138138
self,
139139
input_wav: torch.Tensor,
140140
block_frame_16k: int,
141-
skip_head: torch.Tensor,
141+
skip_head: int,
142142
return_length: int,
143143
f0method: Union[tuple, str],
144144
inp_f0: Optional[np.ndarray] = None,
@@ -241,8 +241,6 @@ def infer(
241241
feats = feats.to(feats0.dtype)
242242
p_len = torch.LongTensor([p_len]).to(self.device)
243243
sid = torch.LongTensor([0]).to(self.device)
244-
skip_head = torch.LongTensor([skip_head])
245-
return_length = torch.LongTensor([return_length])
246244
with torch.no_grad():
247245
infered_audio = (
248246
self.net_g.infer(
@@ -253,6 +251,7 @@ def infer(
253251
pitchf=cache_pitchf,
254252
skip_head=skip_head,
255253
return_length=return_length,
254+
return_length2=return_length2,
256255
)
257256
.squeeze(1)
258257
.float()

rvc/layers/encoders.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,21 +123,21 @@ def __call__(
123123
phone: torch.Tensor,
124124
pitch: torch.Tensor,
125125
lengths: torch.Tensor,
126-
# skip_head: Optional[torch.Tensor] = None,
126+
skip_head: Optional[int] = None,
127127
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128128
return super().__call__(
129129
phone,
130130
pitch,
131131
lengths,
132-
# skip_head=skip_head,
132+
skip_head=skip_head,
133133
)
134134

135135
def forward(
136136
self,
137137
phone: torch.Tensor,
138138
pitch: torch.Tensor,
139139
lengths: torch.Tensor,
140-
# skip_head: Optional[torch.Tensor] = None,
140+
skip_head: Optional[int] = None,
141141
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
142142
x = self.emb_phone(phone)
143143
if pitch is not None:
@@ -150,13 +150,10 @@ def forward(
150150
1,
151151
).to(x.dtype)
152152
x = self.encoder(x * x_mask, x_mask)
153-
"""
154153
if skip_head is not None:
155-
assert isinstance(skip_head, torch.Tensor)
156-
head = int(skip_head.item())
154+
head = int(skip_head)
157155
x = x[:, :, head:]
158156
x_mask = x_mask[:, :, head:]
159-
"""
160157
stats: torch.Tensor = self.proj(x) * x_mask
161158
m, logs = torch.split(stats, self.out_channels, dim=1)
162159
return m, logs, x_mask

rvc/layers/generators.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,21 @@ def __call__(
6161
self,
6262
x: torch.Tensor,
6363
g: Optional[torch.Tensor] = None,
64-
# n_res: Optional[torch.Tensor] = None,
64+
n_res: Optional[int] = None,
6565
) -> torch.Tensor:
66-
return super().__call__(x, g=g)
66+
return super().__call__(x, g=g, n_res=n_res)
6767

6868
def forward(
6969
self,
7070
x: torch.Tensor,
7171
g: Optional[torch.Tensor] = None,
72-
# n_res: Optional[torch.Tensor] = None,
72+
n_res: Optional[int] = None,
7373
):
74-
"""
7574
if n_res is not None:
76-
assert isinstance(n_res, torch.Tensor)
77-
n = int(n_res.item())
75+
n = int(n_res)
7876
if n != x.shape[-1]:
7977
x = F.interpolate(x, size=n, mode="linear")
80-
"""
78+
8179
x = self.conv_pre(x)
8280
if g is not None:
8381
x = x + self.cond(g)

rvc/layers/nsf.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,28 +136,27 @@ def __call__(
136136
x: torch.Tensor,
137137
f0: torch.Tensor,
138138
g: Optional[torch.Tensor] = None,
139-
# n_res: Optional[torch.Tensor] = None,
139+
n_res: Optional[int] = None,
140140
) -> torch.Tensor:
141-
return super().__call__(x, f0, g=g)
141+
return super().__call__(x, f0, g=g, n_res=n_res)
142142

143143
def forward(
144144
self,
145145
x: torch.Tensor,
146146
f0: torch.Tensor,
147147
g: Optional[torch.Tensor] = None,
148-
# n_res: Optional[torch.Tensor] = None,
148+
n_res: Optional[int] = None,
149149
) -> torch.Tensor:
150150
har_source = self.m_source(f0, self.upp)
151151
har_source = har_source.transpose(1, 2)
152-
"""
152+
153153
if n_res is not None:
154-
assert isinstance(n_res, torch.Tensor)
155-
n = int(n_res.item())
156-
if n * self.upp != har_source.shape[-1]:
157-
har_source = F.interpolate(har_source, size=n * self.upp, mode="linear")
158-
if n != x.shape[-1]:
159-
x = F.interpolate(x, size=n, mode="linear")
160-
"""
154+
n_res = int(n_res)
155+
if n_res * self.upp != har_source.shape[-1]:
156+
har_source = F.interpolate(har_source, size=n_res * self.upp, mode="linear")
157+
if n_res != x.shape[-1]:
158+
x = F.interpolate(x, size=n_res, mode="linear")
159+
161160
x = self.conv_pre(x)
162161
if g is not None:
163162
x = x + self.cond(g)

rvc/layers/synthesizers.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,18 @@ def infer(
177177
sid: torch.Tensor,
178178
pitch: Optional[torch.Tensor] = None,
179179
pitchf: Optional[torch.Tensor] = None, # nsff0
180-
skip_head: Optional[torch.Tensor] = None,
181-
return_length: Optional[torch.Tensor] = None,
182-
# return_length2: Optional[torch.Tensor] = None,
180+
skip_head: Optional[int] = None,
181+
return_length: Optional[int] = None,
182+
return_length2: Optional[int] = None,
183183
):
184184
g = self.emb_g(sid).unsqueeze(-1)
185185
if skip_head is not None and return_length is not None:
186-
head = int(skip_head.item())
187-
length = int(return_length.item())
188-
flow_head = torch.clamp(skip_head - 24, min=0)
189-
dec_head = head - int(flow_head.item())
190-
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
186+
head = int(skip_head)
187+
length = int(return_length)
188+
flow_head = head - 24
189+
if flow_head < 0: flow_head = 0
190+
dec_head = head - flow_head
191+
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, head)
191192
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
192193
z = self.flow(z_p, x_mask, g=g, reverse=True)
193194
z = z[:, :, dec_head : dec_head + length]
@@ -204,13 +205,13 @@ def infer(
204205
z * x_mask,
205206
pitchf,
206207
g=g,
207-
# n_res=return_length2,
208+
n_res=return_length2,
208209
)
209210
else:
210211
o = self.dec(
211212
z * x_mask,
212213
g=g,
213-
# n_res=return_length2
214+
n_res=return_length2
214215
)
215216
del x_mask, z
216217
return o # , x_mask, (z, z_p, m_p, logs_p)

0 commit comments

Comments
 (0)