We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9e59375 commit f27a991Copy full SHA for f27a991
extract_feature_print.py
@@ -51,7 +51,8 @@ def readwave(wav_path, normalize=False):
51
)
52
model = models[0]
53
model = model.to(device)
54
-model = model.half()
+if torch.cuda.is_available():
55
+ model = model.half()
56
model.eval()
57
58
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
@@ -70,7 +71,7 @@ def readwave(wav_path, normalize=False):
70
71
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
72
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
73
inputs = {
- "source": feats.half().to(device),
74
+ "source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device),
75
"padding_mask": padding_mask.to(device),
76
"output_layer": 9, # layer 9
77
}
0 commit comments