Skip to content

Commit 38fddaf

Browse files
feat: show embed shape
1 parent 9f4f296 commit 38fddaf

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ loss = model(x, embedding=embedding)
105105
loss.backward()
106106

107107
# Given start embedding and noise sample new source
108-
embedding = torch.randn(1, 64, 768)
109-
noise = torch.randn(1, 1, 2 ** 18)
108+
embedding = torch.randn(2, 64, 768)
109+
noise = torch.randn(2, 1, 2 ** 18)
110110
sampled = model.sample(
111111
noise,
112112
embedding=embedding,
113113
embedding_scale=5.0, # Classifier-free guidance scale
114114
num_steps=5
115-
) # [1, 1, 2 ** 18]
115+
) # [2, 1, 2 ** 18]
116116
```
117117

118118
#### Text Conditional Generation
@@ -122,7 +122,7 @@ You can generate embeddings from text by using a pretrained frozen T5 transforme
122122
from audio_diffusion_pytorch import T5Embedder
123123

124124
embedder = T5Embedder(model='t5-base', max_length=64)
125-
embedding = embedder(["First batch item text...", "Second batch item text..."]) # [1, 64, 768]
125+
embedding = embedder(["First batch item text...", "Second batch item text..."]) # [2, 64, 768]
126126

127127
loss = model(x, embedding=embedding)
128128
# ...

0 commit comments

Comments
 (0)