Skip to content

Commit f27a991

Browse files
committed
fix: extract freture cannot run on pure cpu
1 parent 9e59375 commit f27a991

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

extract_feature_print.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def readwave(wav_path, normalize=False):
5151
)
5252
model = models[0]
5353
model = model.to(device)
54-
model = model.half()
54+
if torch.cuda.is_available():
55+
model = model.half()
5556
model.eval()
5657

5758
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
@@ -70,7 +71,7 @@ def readwave(wav_path, normalize=False):
7071
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
7172
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
7273
inputs = {
73-
"source": feats.half().to(device),
74+
"source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device),
7475
"padding_mask": padding_mask.to(device),
7576
"output_layer": 9, # layer 9
7677
}

0 commit comments

Comments
 (0)