Skip to content

Commit 7fabeb3

Browse files
fix: t5 embedder device
1 parent d3f27f8 commit 7fabeb3

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,10 +1153,14 @@ def forward(self, texts: List[str]) -> Tensor:
11531153
return_tensors="pt",
11541154
)
11551155

1156+
device = next(self.transformer.parameters()).device
1157+
input_ids = encoded["input_ids"].to(device)
1158+
attention_mask = encoded["attention_mask"].to(device)
1159+
11561160
self.transformer.eval()
11571161

11581162
embedding = self.transformer(
1159-
input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
1163+
input_ids=input_ids, attention_mask=attention_mask
11601164
)["last_hidden_state"]
11611165

11621166
return embedding

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.63",
6+
version="0.0.64",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)