From 8fdbd7f454e556cf5586a69f4f9b1bce73b621e1 Mon Sep 17 00:00:00 2001
From: primepake
Date: Mon, 23 Jun 2025 14:43:51 +0000
Subject: [PATCH 1/2] adding STG for audio inference

---
 cosyvoice/flow/decoder.py       | 87 +++++++++++++++++++++++++++------
 cosyvoice/flow/flow_matching.py | 73 +++++++++++++++++++++++++--
 third_party/Matcha-TTS          |  2 +-
 3 files changed, 141 insertions(+), 21 deletions(-)

diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index 97768a459..ae76f2a69 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -97,6 +97,10 @@ def __init__(
         num_mid_blocks=2,
         num_heads=4,
         act_fn="snake",
+        stg_applied_layers_idx=None,
+        stg_scale=0.0,
+        do_rescaling=False,
+        stg_mode="attention"
     ):
         """
         This decoder requires an input with the same shape of the target. So, if your text content
@@ -114,6 +118,12 @@ def __init__(
             time_embed_dim=time_embed_dim,
             act_fn="silu",
         )
+
+        self.stg_applied_layers_idx = stg_applied_layers_idx or []
+        self.stg_scale = stg_scale
+        self.do_rescaling = do_rescaling
+        self.stg_mode = stg_mode
+
         self.down_blocks = nn.ModuleList([])
         self.mid_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -238,6 +248,8 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
 
         hiddens = []
         masks = [mask]
+
+        layer_idx = 0
         for resnet, transformer_blocks, downsample in self.down_blocks:
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
@@ -245,11 +257,26 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                )
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        x = transformer_block.forward_with_stg(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                    else:
+                        x = transformer_block.forward_with_stg_residual(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                else:
+                    x = transformer_block(
+                        hidden_states=x,
+                        attention_mask=attn_mask,
+                        timestep=t,
+                    )
+                layer_idx += 1
             x = rearrange(x, "b t c -> b c t").contiguous()
             hiddens.append(x)  # Save hidden states for skip connections
             x = downsample(x * mask_down)
@@ -263,11 +290,26 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                )
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        x = transformer_block.forward_with_stg(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                    else:
+                        x = transformer_block.forward_with_stg_residual(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                else:
+                    x = transformer_block(
+                        hidden_states=x,
+                        attention_mask=attn_mask,
+                        timestep=t,
+                    )
+                layer_idx += 1
             x = rearrange(x, "b t c -> b c t").contiguous()
 
         for resnet, transformer_blocks, upsample in self.up_blocks:
@@ -279,11 +321,26 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                )
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        x = transformer_block.forward_with_stg(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                    else:
+                        x = transformer_block.forward_with_stg_residual(
+                            hidden_states=x,
+                            attention_mask=attn_mask,
+                            timestep=t,
+                        )
+                else:
+                    x = transformer_block(
+                        hidden_states=x,
+                        attention_mask=attn_mask,
+                        timestep=t,
+                    )
+                layer_idx += 1
             x = rearrange(x, "b t c -> b c t").contiguous()
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 39b3415e0..e981d903c 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -122,9 +122,65 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
 
         return sol[-1].float()
 
-    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
+    def solve_euler_stg(self, x, t_span, mu, mask, spks, cond, streaming=False):
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
+
+        sol = []
+
+        x_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([3, 1, x.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([3], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([3, 80], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)
+
+        for step in range(1, len(t_span)):
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            mu_in[2] = mu
+            t_in[:] = t.unsqueeze(0)
+            if spks is not None:
+                spks_in[0] = spks
+                spks_in[2] = spks
+            if cond is not None:
+                cond_in[0] = cond
+                cond_in[2] = cond
+
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in,
+                streaming=streaming,
+                use_stg=True)
+
+            dphi_dt_cond, dphi_dt_uncond, dphi_dt_perturb = torch.split(dphi_dt, [x.size(0), x.size(0), x.size(0)], dim=0)
+
+            dphi_dt = dphi_dt_uncond + 3.12 * (dphi_dt_cond - dphi_dt_uncond) + self.stg_scale * (dphi_dt_cond - dphi_dt_perturb)
+
+            if self.do_rescaling:
+                rescaling_scale = 0.7
+                factor = dphi_dt_cond.std() / dphi_dt.std()
+                factor = rescaling_scale * factor + (1 - rescaling_scale)
+                dphi_dt = dphi_dt * factor
+
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+
+        return sol[-1].float()
+
+
+    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False, use_stg=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
+            if use_stg:
+                return self.estimator.forward_with_stg(x, mask, mu, t, spks, cond, streaming=streaming)
+            else:
+                return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
             [estimator, stream], trt_engine = self.estimator.acquire_estimator()
             with stream:
@@ -192,13 +248,17 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
 
 
 class CausalConditionalCFM(ConditionalCFM):
-    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None, stg_applied_layers_idx=None, stg_scale=0.0, do_rescaling=False, stg_mode="attention"):
         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
         set_all_random_seed(0)
+        self.stg_applied_layers_idx = stg_applied_layers_idx or []
+        self.stg_scale = stg_scale
+        self.do_rescaling = do_rescaling
+        self.stg_mode = stg_mode
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False, use_stg=False):
         """Forward diffusion
 
         Args:
@@ -222,4 +282,7 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None,
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
+        if not use_stg:
+            return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
+        else:
+            return self.solve_euler_stg(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
diff --git a/third_party/Matcha-TTS b/third_party/Matcha-TTS
index dd9105b34..556179b33 160000
--- a/third_party/Matcha-TTS
+++ b/third_party/Matcha-TTS
@@ -1 +1 @@
-Subproject commit dd9105b34bf2be2230f4aa1e4769fb586a3c824e
+Subproject commit 556179b337cdc207042f502b3c6ba9890cde300c

From 319934911324cb2ee95b59fcc113826bc440f5b1 Mon Sep 17 00:00:00 2001
From: primepake
Date: Mon, 23 Jun 2025 15:06:38 +0000
Subject: [PATCH 2/2] add stg

---
 cosyvoice/flow/decoder.py | 222 +++++++++++++++++++++++++++-----------
 1 file changed, 159 insertions(+), 63 deletions(-)

diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index ae76f2a69..5fe49dcda 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from typing import Any, Dict, Optional
+import types
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -85,6 +86,119 @@ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
         self.block2 = CausalBlock1D(dim_out, dim_out)
 
 
+def forward_with_stg(
+    self,
+    hidden_states: torch.FloatTensor,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    timestep: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+) -> torch.FloatTensor:
+
+    num_prompt = hidden_states.size(0) // 3
+    hidden_states_ptb = hidden_states[2 * num_prompt:]
+
+    if self.use_ada_layer_norm:
+        norm_hidden_states = self.norm1(hidden_states, timestep)
+    elif self.use_ada_layer_norm_zero:
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+        )
+    else:
+        norm_hidden_states = self.norm1(hidden_states)
+
+    cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+    attn_output = self.attn1(
+        norm_hidden_states,
+        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+        attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+        **cross_attention_kwargs,
+    )
+    if self.use_ada_layer_norm_zero:
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+    hidden_states = attn_output + hidden_states
+
+    # 2. Cross-Attention
+    if self.attn2 is not None:
+        norm_hidden_states = (
+            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+        )
+
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=encoder_attention_mask,
+            **cross_attention_kwargs,
+        )
+        hidden_states = attn_output + hidden_states
+
+    hidden_states[2*num_prompt:] = hidden_states_ptb
+
+    # 3. Feed-forward
+    norm_hidden_states = self.norm3(hidden_states)
+
+    if self.use_ada_layer_norm_zero:
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+    if self._chunk_size is not None:
+        # "feed_forward_chunk_size" can be used to save memory
+        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+            raise ValueError(
+                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+            )
+
+        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+        ff_output = torch.cat(
+            [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+            dim=self._chunk_dim,
+        )
+    else:
+        ff_output = self.ff(norm_hidden_states)
+
+    if self.use_ada_layer_norm_zero:
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+    hidden_states = ff_output + hidden_states
+    return hidden_states
+
+
+def forward_with_stg_residual(
+    self,
+    hidden_states: torch.FloatTensor,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    timestep: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+) -> torch.FloatTensor:
+    # Split batch for perturbation (last third is perturbed)
+    num_prompt = hidden_states.size(0) // 3
+    hidden_states_ptb = hidden_states[2 * num_prompt:]
+
+    # Apply normal forward pass to all samples
+    output = self.forward(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_attention_mask=encoder_attention_mask,
+        timestep=timestep,
+        cross_attention_kwargs=cross_attention_kwargs,
+        class_labels=class_labels,
+    )
+
+    # Replace perturbed samples with their input (residual skip)
+    output[2 * num_prompt:] = hidden_states_ptb
+
+    return output
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
@@ -128,6 +242,8 @@ def __init__(
         self.mid_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
+        layer_idx = 0
+
         output_channel = in_channels
         for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
             input_channel = output_channel
@@ -146,6 +262,15 @@ def __init__(
                     for _ in range(n_blocks)
                 ]
             )
+            # Bind STG methods to transformer blocks if applicable
+            for block in transformer_blocks:
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        block.forward = types.MethodType(forward_with_stg, block)
+                    else:  # residual
+                        block.forward = types.MethodType(forward_with_stg_residual, block)
+                layer_idx += 1
+
             downsample = (
                 Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
@@ -168,6 +293,14 @@ def __init__(
                    for _ in range(n_blocks)
                 ]
             )
+            # Bind STG methods to transformer blocks if applicable
+            for block in transformer_blocks:
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        block.forward = types.MethodType(forward_with_stg, block)
+                    else:  # residual
+                        block.forward = types.MethodType(forward_with_stg_residual, block)
+                layer_idx += 1
 
             self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
 
@@ -193,6 +326,14 @@ def __init__(
                     for _ in range(n_blocks)
                 ]
             )
+            for block in transformer_blocks:
+                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
+                    if self.stg_mode == "attention":
+                        block.forward = types.MethodType(forward_with_stg, block)
+                    else:  # residual
+                        block.forward = types.MethodType(forward_with_stg_residual, block)
+                layer_idx += 1
+
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
@@ -248,7 +389,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
 
         hiddens = []
         masks = [mask]
 
-        layer_idx = 0
+
         for resnet, transformer_blocks, downsample in self.down_blocks:
             mask_down = masks[-1]
@@ -256,27 +397,12 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             x = rearrange(x, "b c t -> b t c").contiguous()
             attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for transformer_block in transformer_blocks:
-                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
-                    if self.stg_mode == "attention":
-                        x = transformer_block.forward_with_stg(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                    else:
-                        x = transformer_block.forward_with_stg_residual(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                else:
-                    x = transformer_block(
-                        hidden_states=x,
-                        attention_mask=attn_mask,
-                        timestep=t,
-                    )
-                layer_idx += 1
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
             x = rearrange(x, "b t c -> b c t").contiguous()
             hiddens.append(x)  # Save hidden states for skip connections
             x = downsample(x * mask_down)
@@ -290,26 +416,11 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
-                    if self.stg_mode == "attention":
-                        x = transformer_block.forward_with_stg(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                    else:
-                        x = transformer_block.forward_with_stg_residual(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                else:
-                    x = transformer_block(
-                        hidden_states=x,
-                        attention_mask=attn_mask,
-                        timestep=t,
-                    )
-                layer_idx += 1
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
             x = rearrange(x, "b t c -> b c t").contiguous()
 
         for resnet, transformer_blocks, upsample in self.up_blocks:
@@ -321,26 +432,11 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
             attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
-                    if self.stg_mode == "attention":
-                        x = transformer_block.forward_with_stg(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                    else:
-                        x = transformer_block.forward_with_stg_residual(
-                            hidden_states=x,
-                            attention_mask=attn_mask,
-                            timestep=t,
-                        )
-                else:
-                    x = transformer_block(
-                        hidden_states=x,
-                        attention_mask=attn_mask,
-                        timestep=t,
-                    )
-                layer_idx += 1
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
             x = rearrange(x, "b t c -> b c t").contiguous()
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
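
Note for reviewers: solve_euler_stg in patch 1 stacks the conditional, unconditional, and skip-layer-perturbed estimates into one batch of three and then mixes them with a hard-coded CFG weight of 3.12 plus the STG correction, optionally rescaling the result toward the conditional branch's standard deviation. The sketch below restates that combination as a standalone helper so the math can be read (and unit-tested) in isolation. The helper name combine_stg and its default arguments are illustrative only and are not part of this patch.

import torch


def combine_stg(dphi_dt_cond: torch.Tensor,
                dphi_dt_uncond: torch.Tensor,
                dphi_dt_perturb: torch.Tensor,
                cfg_scale: float = 3.12,
                stg_scale: float = 1.0,
                do_rescaling: bool = False,
                rescaling_scale: float = 0.7) -> torch.Tensor:
    # Classifier-free guidance plus the STG correction term, mirroring solve_euler_stg.
    dphi_dt = (dphi_dt_uncond
               + cfg_scale * (dphi_dt_cond - dphi_dt_uncond)
               + stg_scale * (dphi_dt_cond - dphi_dt_perturb))
    if do_rescaling:
        # Blend the guided velocity's std back toward the conditional branch's std.
        factor = rescaling_scale * (dphi_dt_cond.std() / dphi_dt.std()) + (1 - rescaling_scale)
        dphi_dt = dphi_dt * factor
    return dphi_dt

With stg_scale set to 0 and do_rescaling disabled, this reduces to plain classifier-free guidance at the fixed 3.12 weight used in the patch.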