Skip to content

Commit fdf55ba

Browse files
feat: add number embedder
1 parent dbe3b91 commit fdf55ba

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ sampled = model.sample(
130130
)
131131
```
132132

133+
#### Number Conditional Generation
134+
135+
```py
136+
from audio_diffusion_pytorch import NumberEmbedder
137+
138+
embedder = NumberEmbedder(features=768)
139+
embedding = embedder([0.1, 0.2]) # [2, 768]
140+
```
141+
142+
133143
## Usage with Components
134144

135145
### UNet1d

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@
3333
DiffusionVocoder1d,
3434
Model1d,
3535
)
36-
from .modules import T5Embedder, UNet1d, UNetConditional1d
36+
from .modules import NumberEmbedder, T5Embedder, UNet1d, UNetConditional1d

audio_diffusion_pytorch/modules.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,28 @@ def forward(self, texts: List[str]) -> Tensor:
12711271
return embedding
12721272

12731273

1274+
class NumberEmbedder(nn.Module):
1275+
def __init__(
1276+
self,
1277+
features: int,
1278+
dim: int = 256,
1279+
):
1280+
super().__init__()
1281+
self.features = features
1282+
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1283+
1284+
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1285+
if not torch.is_tensor(x):
1286+
device = next(self.embedding.parameters()).device
1287+
x = torch.tensor(x, device=device)
1288+
assert isinstance(x, Tensor)
1289+
shape = x.shape
1290+
x = rearrange(x, "... -> (...)")
1291+
embedding = self.embedding(x)
1292+
x = embedding.view(*shape, self.features)
1293+
return x # type: ignore
1294+
1295+
12741296
"""
12751297
Audio Transforms
12761298
"""

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.86",
6+
version="0.0.87",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)