Skip to content

Commit 315ab44

Browse files
feat: add pretrained/frozen T5 embedder
1 parent d16dfa7 commit 315ab44

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

README.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,19 @@ from audio_diffusion_pytorch import AudioDiffusionConditional
9393

9494
model = AudioDiffusionConditional(
9595
in_channels=1,
96-
embedding_max_length=512,
96+
embedding_max_length=64,
9797
embedding_features=768,
9898
embedding_mask_proba=0.1 # Conditional dropout of batch elements
9999
)
100100

101101
# Train on pairs of audio and embedding data (e.g. from a transformer output)
102102
x = torch.randn(2, 1, 2 ** 18)
103-
embedding = torch.randn(2, 512, 768)
103+
embedding = torch.randn(2, 64, 768)
104104
loss = model(x, embedding=embedding)
105105
loss.backward()
106106

107107
# Given start embedding and noise sample new source
108-
embedding = torch.randn(1, 512, 768)
108+
embedding = torch.randn(1, 64, 768)
109109
noise = torch.randn(1, 1, 2 ** 18)
110110
sampled = model.sample(
111111
noise,
@@ -115,6 +115,25 @@ sampled = model.sample(
115115
) # [1, 1, 2 ** 18]
116116
```
117117

118+
#### Text Conditional Generation
119+
You can generate embeddings from text by using a pretrained frozen T5 transformer with `T5Embedder`, as follows (note that this requires `pip install transformers`):
120+
121+
```py
122+
from audio_diffusion_pytorch import T5Embedder
123+
124+
embedder = T5Embedder(model='t5-base', max_length=64)
125+
embedding = embedder(["First batch item text...", "Second batch item text..."])
126+
127+
loss = model(x, embedding=embedding)
128+
# ...
129+
sampled = model.sample(
130+
noise,
131+
embedding=embedding,
132+
embedding_scale=5.0, # Classifier-free guidance scale
133+
num_steps=5
134+
)
135+
```
136+
118137
## Usage with Components
119138

120139
### UNet1d

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .modules import (
2828
AutoEncoder1d,
2929
MultiEncoder1d,
30+
T5Embedder,
3031
UNet1d,
3132
UNetConditional1d,
3233
Variational,

audio_diffusion_pytorch/modules.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,9 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
10821082
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
10831083

10841084

1085+
""" Conditioning """
1086+
1087+
10851088
class UNetConditional1d(UNet1d):
10861089
"""
10871090
UNet1d with classifier-free guidance on the token embeddings
@@ -1130,6 +1133,35 @@ def forward( # type: ignore
11301133
return out
11311134

11321135

1136+
class T5Embedder(nn.Module):
1137+
def __init__(self, model: str = "t5-base", max_length: int = 64):
1138+
super().__init__()
1139+
from transformers import T5EncoderModel, T5Tokenizer
1140+
1141+
self.tokenizer = T5Tokenizer.from_pretrained(model)
1142+
self.transformer = T5EncoderModel.from_pretrained(model)
1143+
self.max_length = max_length
1144+
1145+
@torch.no_grad()
1146+
def forward(self, texts: List[str]) -> Tensor:
1147+
1148+
encoded = self.tokenizer(
1149+
texts,
1150+
truncation=True,
1151+
max_length=self.max_length,
1152+
padding="max_length",
1153+
return_tensors="pt",
1154+
)
1155+
1156+
self.transformer.eval()
1157+
1158+
embedding = self.transformer(
1159+
input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
1160+
)["last_hidden_state"]
1161+
1162+
return embedding
1163+
1164+
11331165
"""
11341166
Encoders / Decoders
11351167
"""

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.61",
6+
version="0.0.62",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)