Skip to content

Commit 4396eab

Browse files
feat: add xy embedding to alignment
1 parent 4c2e66a commit 4396eab

File tree

6 files changed

+146
-65
lines changed

6 files changed

+146
-65
lines changed

README.md

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

2-
# Aligner - PyTorch
2+
# Aligner - PyTorch
33

4-
Sequence alignement methods with helpers for PyTorch.
4+
Sequence alignment methods with helpers for PyTorch.
55

66
## Install
77

@@ -16,12 +16,12 @@ pip install aligner-pytorch
1616

1717
### MAS
1818

19-
MAS (Monotonic Alignment Search) from GlowTTS. This can be used to get the alignment of any (similarity) matrix. Implementation in optimized Cython.
19+
MAS (Monotonic Alignment Search) from GlowTTS. This can be used to get the alignment of any (similarity) matrix. Implementation in optimized Cython.
2020

2121
```py
22-
from aligner_pytorch import mas
22+
from aligner_pytorch import mas
2323

24-
sim = torch.rand(1, 4, 6) # [batch_size, m_rows, n_cols]
24+
sim = torch.rand(1, 4, 6) # [batch_size, x_length, y_length]
2525
alignment = mas(sim)
2626

2727
"""
@@ -41,14 +41,60 @@ alignment = tensor([[
4141
"""
4242
```
4343

44+
### XY Embedding to Alignment
45+
Used during training to get the alignment of an `x_embedding` with a `y_embedding`; it computes the log probability from a normal distribution and the alignment with MAS.
46+
```py
47+
from aligner_pytorch import get_alignment_from_embeddings
48+
49+
x_embedding = torch.randn(1, 4, 10)
50+
y_embedding = torch.randn(1, 6, 10)
51+
52+
alignment = get_alignment_from_embeddings(
53+
x_embedding=torch.randn(1, 4, 10), # [batch_size, x_length, features]
54+
y_embedding=torch.randn(1, 6, 10), # [batch_size, y_length, features]
55+
) # [batch_size, x_length, y_length]
56+
57+
"""
58+
alignment = tensor([[
59+
[1, 0, 0, 0, 0, 0],
60+
[0, 1, 0, 0, 0, 0],
61+
[0, 0, 1, 0, 0, 0],
62+
[0, 0, 0, 1, 1, 1]
63+
]], dtype=torch.int32)
64+
"""
65+
```
66+
67+
### Duration Embedding to Alignment
68+
Used during inference to compute the alignment from a trained duration embedding.
69+
```py
70+
from aligner_pytorch import get_alignment_from_duration_embedding
71+
72+
alignment = get_alignment_from_duration_embedding(
73+
embedding=torch.randn(1, 5), # Embedding: [batch_size, x_length]
74+
scale=1.0, # Duration scale
75+
y_length=10 # (Optional) fixes maximum output y_length
76+
) # Output alignment [batch_size, x_length, y_length]
77+
78+
"""
79+
alignment = tensor([[
80+
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
81+
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
82+
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
83+
[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
84+
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
85+
]])
86+
"""
87+
```
88+
89+
4490
## Citations
4591

46-
Monotonic Alignment Search
92+
Monotonic Alignment Search
4793
```bibtex
4894
@misc{2005.11129,
4995
Author = {Jaehyeon Kim and Sungwon Kim and Jungil Kong and Sungroh Yoon},
5096
Title = {Glow-TTS: A Generative Flow for Text-to-Speech via Monotonic Alignment Search},
5197
Year = {2020},
5298
Eprint = {arXiv:2005.11129},
5399
}
54-
```
100+
```

aligner_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from aligner_pytorch.mas import mas
2+
from .aligner import * # noqa

aligner_pytorch/aligner.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
import math
3+
from torch.nn import functional as F
4+
from torch import Tensor
5+
from typing import Optional
6+
from einops import rearrange, reduce, repeat
7+
from .utils import exists
8+
from .mas import mas
9+
10+
11+
@torch.no_grad()
def get_alignment_from_embeddings(
    x_embedding: Tensor,
    y_embedding: Tensor,
    x_mask: Optional[Tensor] = None,
) -> Tensor:
    """Compute a hard MAS alignment between two embedding sequences.

    The pairwise similarity is the log PDF of an isotropic unit-variance
    Gaussian, log N(x | mu=y, sigma=I), evaluated for every (x, y) pair and
    passed through monotonic alignment search.

    Args:
        x_embedding: [batch_size, x_length, features] embeddings.
        y_embedding: [batch_size, y_length, features] embeddings.
        x_mask: optional [batch_size, x_length] boolean mask over x positions.

    Returns:
        [batch_size, x_length, y_length] alignment matrix from MAS.
    """
    _, ty, d = y_embedding.shape
    # log N(x|mu=y, I) = const - 0.5*||y||^2 + <x, y> - 0.5*||x||^2, expanded
    # termwise so the [b, tx, ty] matrix is built without materializing x - y.
    const = -0.5 * math.log(2 * math.pi) * d
    # Scalar -0.5 multiplies replace the original full-size tensor of -0.5s
    # ("factor") and its extra einsum: algebraically identical, allocation-free.
    y_sq = -0.5 * reduce(y_embedding**2, "b ty d -> b 1 ty", "sum")
    x_sq = -0.5 * reduce(x_embedding**2, "b tx d -> b tx 1", "sum")
    xy = torch.einsum("b x d, b y d -> b x y", x_embedding, y_embedding)
    log_prior = y_sq + xy + x_sq + const
    # Broadcast the x mask across all y positions if provided
    a_mask = repeat(x_mask, "b tx -> b tx ty", ty=ty) if exists(x_mask) else None
    # Compute MAS alignment
    alignment = mas(log_prior, mask=a_mask)
    return alignment
30+
31+
32+
def get_sequential_masks(lengths: Tensor, length_max: Optional[int] = None) -> Tensor:
    """Build boolean sequence masks from per-item lengths.

    Args:
        lengths: [b] tensor of sequence lengths.
        length_max: optional number of columns; defaults to ``lengths.max()``.

    Returns:
        [b, length_max] boolean tensor, True where the position index is
        smaller than that item's length.
    """
    if length_max is None:
        length_max = int(lengths.max().item())
    length_max = int(length_max)
    # Compare a row of positions [1, n] against a column of lengths [b, 1]
    positions = torch.arange(length_max).to(lengths).unsqueeze(0)
    bounds = lengths.unsqueeze(-1)
    return positions < bounds
39+
40+
41+
def get_alignment_from_duration(
    duration: Tensor,
    mask: Tensor,
) -> Tensor:
    """Expand per-token durations into a hard monotonic alignment matrix.

    Args:
        duration: [b, tx] duration assigned to each x position.
        mask: [b, tx, ty] mask limiting the valid alignment cells.

    Returns:
        [b, tx, ty] long tensor where row i is 1 exactly over the span of y
        frames assigned to token i by the cumulative durations.
    """
    b, tx, ty = mask.shape
    duration_cum = torch.cumsum(duration, dim=1)
    # Staircase mask built directly on [b, tx, ty]: cell (i, j) is True iff
    # j < cumsum(duration)[i]. This replaces the flatten -> helper -> unflatten
    # round trip with one broadcasted comparison, same values.
    steps = torch.arange(ty).to(duration_cum).view(1, ty)
    paths = (steps < duration_cum.view(-1, 1)).view(b, tx, ty)
    # Each row also covers its predecessors' span; shift the staircase down one
    # row, invert, and multiply to leave a single run of 1s per row.
    paths_mask = ~F.pad(paths, pad=(0, 0, 1, 0))[:, :-1, :]
    # Get single path and mask unused cells
    paths = paths * paths_mask * mask
    return paths.long()
57+
58+
59+
@torch.no_grad()
def get_alignment_from_duration_embedding(
    embedding: Tensor,  # [b, tx]
    scale: float = 1.0,
    mask: Optional[Tensor] = None,  # [b, tx]
    y_length: Optional[int] = None,
) -> Tensor:  # [b, tx, ty]
    """Turn a predicted log-duration embedding into a hard alignment matrix.

    Args:
        embedding: [b, tx] log-domain duration predictions.
        scale: multiplier applied to the (ceiled) durations.
        mask: optional [b, tx] boolean mask over valid x positions.
        y_length: optional fixed number of output columns ty.

    Returns:
        [b, tx, ty] long alignment matrix.
    """
    b, tx = embedding.shape
    device = embedding.device
    # Every x position is valid unless a mask says otherwise
    x_mask = mask if mask is not None else torch.ones((b, tx), device=device).bool()
    assert x_mask.shape == embedding.shape, "mask must have same shape as embedding"
    # Durations: exponentiate out of log domain, ceil, scale, zero masked slots.
    # NOTE(review): scale is applied after ceil, so durations are non-integer
    # whenever scale != 1.0 — confirm this ordering is intended.
    duration = torch.exp(embedding)
    duration = torch.ceil(duration) * scale
    duration = duration * x_mask
    # Per-item total length, clamped so every item spans at least one frame
    duration_total = torch.clamp_min(duration.sum(dim=1), 1).long()
    # Output width: caller-fixed, or the longest item in the batch
    duration_max = y_length if y_length is not None else int(duration_total.max())
    # Valid-cell mask over the [tx, ty] grid
    y_mask = get_sequential_masks(lengths=duration_total, length_max=duration_max)  # type: ignore # noqa
    a_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(1)
    # Expand the masked durations into the final alignment paths
    return get_alignment_from_duration(duration=duration, mask=a_mask)

aligner_pytorch/mas.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,22 @@
22
import torch
33
from torch import Tensor
44
from typing import Optional
5+
from .utils import exists
56

67

78
def mas(x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
8-
b, m, n, device = *x.shape, x.device
9+
device = x.device
910

1011
values = x.detach().clone().to(dtype=torch.float32, device="cpu").numpy()
1112
paths = torch.zeros_like(x, dtype=torch.int32, device="cpu").numpy()
1213

14+
mask = mask.clone() if exists(mask) else torch.ones_like(x)
15+
mask = mask.to(dtype=torch.int32, device="cpu").numpy()
16+
17+
# ms = reduce(mask, 'b m n -> b m', 'sum')[:, 0]
18+
# ns = reduce(mask, 'b m n -> b n', 'sum')[:, 0]
19+
20+
b, m, n = x.shape
1321
ms = torch.tensor([m], dtype=torch.int32).repeat(b).numpy()
1422
ns = torch.tensor([n], dtype=torch.int32).repeat(b).numpy()
1523

aligner_pytorch/utils.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,8 @@
1-
import torch
2-
from torch.nn import functional as F
3-
from torch import Tensor
41
from typing import TypeVar, Optional
52
from typing_extensions import TypeGuard
6-
from einops import rearrange, reduce
73

84
T = TypeVar("T")
95

106

117
def exists(val: Optional[T]) -> TypeGuard[T]:
    """Return True iff *val* is not None, narrowing its type for type checkers."""
    return val is not None
13-
14-
15-
def get_sequential_masks(lengths: Tensor, length_max: Optional[int] = None) -> Tensor:
16-
if not exists(length_max):
17-
length_max = int(lengths.max().item())
18-
length_max = int(length_max)
19-
x = rearrange(torch.arange(length_max).to(lengths), "n -> 1 n")
20-
y = rearrange(lengths, "b -> b 1")
21-
return x < y
22-
23-
24-
def get_alignment_from_duration(
25-
duration: Tensor,
26-
mask: Tensor,
27-
) -> Tensor:
28-
b, tx, ty = mask.shape
29-
duration_cum = torch.cumsum(duration, dim=1)
30-
# Compute paths matrix filled with True on the lower diagonal
31-
paths = get_sequential_masks(
32-
lengths=rearrange(duration_cum, "b tx -> (b tx)"), length_max=ty
33-
)
34-
paths = rearrange(paths, "(b tx) ty -> b tx ty", b=b)
35-
# Get mask paths matrix to get only a single path by padding top and inverting
36-
paths_mask = ~F.pad(paths, pad=(0, 0, 1, 0))[:, :-1, :]
37-
# Get single path and mask unused
38-
paths = paths * paths_mask * mask
39-
return paths.long()
40-
41-
42-
def get_alignment_from_duration_embedding(
43-
embedding: Tensor, # [b, tx]
44-
scale: float = 1.0,
45-
mask: Optional[Tensor] = None, # [b, tx]
46-
max_length: Optional[int] = None,
47-
) -> Tensor: # [b, tx, ty]
48-
b, tx, device = *embedding.shape, embedding.device
49-
# Default mask to all xs if not provided
50-
x_mask = mask if exists(mask) else torch.ones((b, tx), device=device).bool()
51-
assert x_mask.shape == embedding.shape, "mask must have same shape as embedding"
52-
# Get int duration by exponentiating and ceiling, then scaling by duration scale
53-
duration = torch.exp(embedding)
54-
duration = torch.ceil(duration) * scale
55-
duration = duration * x_mask
56-
# Compute total duration per item (clamp if below 1)
57-
duration_total = torch.clamp_min(reduce(duration, "b tx -> b", "sum"), 1).long()
58-
# Get max duration over all items
59-
duration_max = max_length if exists(max_length) else int(duration_total.max())
60-
# Get ys mask and attn matrix mask
61-
y_mask = get_sequential_masks(lengths=duration_total, length_max=duration_max) # type: ignore # noqa
62-
a_mask = rearrange(x_mask, "b tx -> b tx 1") * rearrange(y_mask, "b ty -> b 1 ty")
63-
# Get masked attn paths from duration
64-
return get_alignment_from_duration(duration=duration, mask=a_mask)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="aligner-pytorch",
10-
version="0.0.17",
10+
version="0.0.18",
1111
packages=find_packages(),
1212
license="MIT",
1313
description="Aligner - PyTorch",

0 commit comments

Comments
 (0)