157 changes: 155 additions & 2 deletions cosyvoice/flow/decoder.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Any, Dict, Optional
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -85,6 +86,119 @@ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
self.block2 = CausalBlock1D(dim_out, dim_out)



def forward_with_stg(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:

# The batch is laid out as [conditional, unconditional, perturbed]; keep a copy of the perturbed part
num_prompt = hidden_states.size(0) // 3
hidden_states_ptb = hidden_states[2 * num_prompt:]


# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)


cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states

# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)

attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states

# Restore the perturbed part so that it bypasses both self- and cross-attention
hidden_states[2 * num_prompt:] = hidden_states_ptb

# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)

if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)

num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)

if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = ff_output + hidden_states
return hidden_states

def forward_with_stg_residual(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
# Split batch for perturbation (last third is perturbed)
num_prompt = hidden_states.size(0) // 3
hidden_states_ptb = hidden_states[2 * num_prompt:]

# Apply normal forward pass to all samples
output = self.forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)

# Replace perturbed samples with their input (residual skip)
output[2 * num_prompt:] = hidden_states_ptb

return output
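
# Note (illustrative): both STG variants above assume the batch is laid out as
# [conditional, unconditional, perturbed] with size 3 * num_prompt, matching
# solve_euler_stg in cosyvoice/flow/flow_matching.py. They are bound onto the
# transformer blocks below via types.MethodType, e.g.
#
#     block.forward = types.MethodType(forward_with_stg, block)           # "attention" mode: the perturbed part skips attention
#     block.forward = types.MethodType(forward_with_stg_residual, block)  # "residual" mode: the perturbed part skips the whole block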



class ConditionalDecoder(nn.Module):
def __init__(
self,
@@ -97,6 +211,10 @@ def __init__(
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
stg_applied_layers_idx=None,
stg_scale=0.0,
do_rescaling=False,
stg_mode="attention"
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
@@ -114,10 +232,18 @@ def __init__(
time_embed_dim=time_embed_dim,
act_fn="silu",
)

self.stg_applied_layers_idx = stg_applied_layers_idx or []
self.stg_scale = stg_scale
self.do_rescaling = do_rescaling
self.stg_mode = stg_mode

self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

layer_idx = 0

output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
@@ -136,6 +262,15 @@ def __init__(
for _ in range(n_blocks)
]
)
# Bind STG methods to transformer blocks if applicable
for block in transformer_blocks:
if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
if self.stg_mode == "attention":
block.forward = types.MethodType(forward_with_stg, block)
else: # residual
block.forward = types.MethodType(forward_with_stg_residual, block)
layer_idx += 1

downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
@@ -158,6 +293,14 @@ def __init__(
for _ in range(n_blocks)
]
)
# Bind STG methods to transformer blocks if applicable
for block in transformer_blocks:
if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
if self.stg_mode == "attention":
block.forward = types.MethodType(forward_with_stg, block)
else: # residual
block.forward = types.MethodType(forward_with_stg_residual, block)
layer_idx += 1

self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

@@ -183,6 +326,14 @@ def __init__(
for _ in range(n_blocks)
]
)
# Bind STG methods to transformer blocks if applicable
for block in transformer_blocks:
if self.stg_scale > 0 and layer_idx in self.stg_applied_layers_idx:
if self.stg_mode == "attention":
block.forward = types.MethodType(forward_with_stg, block)
else: # residual
block.forward = types.MethodType(forward_with_stg_residual, block)
layer_idx += 1

upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
@@ -238,13 +389,15 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

hiddens = []
masks = [mask]


for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
73 changes: 68 additions & 5 deletions cosyvoice/flow/flow_matching.py
@@ -122,9 +122,65 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):

return sol[-1].float()

def solve_euler_stg(self, x, t_span, mu, mask, spks, cond, streaming=False):
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)

sol = []

# Batched inputs for the three guidance branches: [conditional, unconditional, perturbed]
x_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([3, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([3], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([3, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([3, 80, x.size(2)], device=x.device, dtype=x.dtype)

for step in range(1, len(t_span)):
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
mu_in[2] = mu
t_in[:] = t.unsqueeze(0)
if spks is not None:
spks_in[0] = spks
spks_in[2] = spks
if cond is not None:
cond_in[0] = cond
cond_in[2] = cond

dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
streaming=streaming,
use_stg=True)

dphi_dt_cond, dphi_dt_uncond, dphi_dt_perturb = torch.split(dphi_dt, [x.size(0), x.size(0), x.size(0)], dim=0)

# Classifier-free guidance (hard-coded scale 3.12) plus STG guidance away from the perturbed prediction
dphi_dt = dphi_dt_uncond + 3.12 * (dphi_dt_cond - dphi_dt_uncond) + self.stg_scale * (dphi_dt_cond - dphi_dt_perturb)

if self.do_rescaling:
rescaling_scale = 0.7
factor = dphi_dt_cond.std() / dphi_dt.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
dphi_dt = dphi_dt * factor

x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t

return sol[-1].float()


def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False, use_stg=False):
if isinstance(self.estimator, torch.nn.Module):
if use_stg:
return self.estimator.forward_with_stg(x, mask, mu, t, spks, cond, streaming=streaming)
else:
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
with stream:
@@ -192,13 +248,17 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):


class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None, stg_applied_layers_idx=None, stg_scale=0.0, do_rescaling=False, stg_mode="attention"):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
set_all_random_seed(0)
self.stg_applied_layers_idx = stg_applied_layers_idx or []
self.stg_scale = stg_scale
self.do_rescaling = do_rescaling
self.stg_mode = stg_mode
self.rand_noise = torch.randn([1, 80, 50 * 300])

@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False, use_stg=False):
"""Forward diffusion

Args:
@@ -222,4 +282,7 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None,
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
if not use_stg:
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
else:
return self.solve_euler_stg(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
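
A minimal, self-contained sketch of the guidance combination performed inside solve_euler_stg, using dummy tensors only (torch is the only dependency; the 3.12 classifier-free-guidance scale and the 0.7 rescaling factor mirror the hard-coded values above, and stg_scale / do_rescaling stand in for the corresponding constructor arguments):

import torch

stg_scale, do_rescaling = 1.0, True

# One velocity estimate per guidance branch: conditional, unconditional, attention-perturbed
dphi_dt_cond, dphi_dt_uncond, dphi_dt_perturb = torch.randn(3, 1, 80, 100).unbind(0)

# Classifier-free guidance plus STG guidance away from the perturbed prediction
dphi_dt = (dphi_dt_uncond
           + 3.12 * (dphi_dt_cond - dphi_dt_uncond)
           + stg_scale * (dphi_dt_cond - dphi_dt_perturb))

# Optional rescaling towards the statistics of the conditional prediction
if do_rescaling:
    rescaling_scale = 0.7
    factor = rescaling_scale * (dphi_dt_cond.std() / dphi_dt.std()) + (1 - rescaling_scale)
    dphi_dt = dphi_dt * factor
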
2 changes: 1 addition & 1 deletion third_party/Matcha-TTS