-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
150 lines (106 loc) · 3.78 KB
/
inference.py
File metadata and controls
150 lines (106 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import json
import argparse
from pathlib import Path
import soundfile as sf
from scipy.signal import resample
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from safetensors import safe_open
from model import XeusForCTC
def ctc_greedy_decoder(logits, vocab):
    """Greedy CTC decoding: per-frame argmax, collapse repeats, drop [PAD].

    Args:
        logits: (batch, time, vocab_size) unnormalized scores.
        vocab: mapping char -> index; must contain "[PAD]" (the CTC blank).

    Returns:
        List of decoded strings (one per batch item); "|" is rendered as a space.
    """
    blank_id = vocab["[PAD]"]
    # Softmax is monotonic, so the argmax is unchanged; kept for parity with
    # the probability view of the logits.
    probabilities = F.softmax(logits, dim=-1)
    id_to_token = {idx: tok for tok, idx in vocab.items()}
    best_path = torch.argmax(probabilities, dim=-1)

    results = []
    for frame_ids in best_path.tolist():
        pieces = []
        last_id = None
        for token_id in frame_ids:
            # Emit only on a change of token, and never for the blank.
            if token_id != last_id and token_id != blank_id:
                pieces.append(id_to_token.get(token_id, "[UNK]"))
            last_id = token_id
        results.append("".join(pieces).replace("|", " "))
    return results
def load_model(config, ckpt_path):
    """Build an XeusForCTC model from `config`, restore its weights from
    `{ckpt_path}/model.safetensors`, and return it in eval mode."""
    model = XeusForCTC(config)
    # Pull every tensor out of the safetensors checkpoint in one pass.
    with safe_open(f"{ckpt_path}/model.safetensors", framework="pt") as tensors:
        state_dict = {name: tensors.get_tensor(name) for name in tensors.keys()}
    model.load_state_dict(state_dict)
    # Inference only: disable dropout etc.
    model.eval()
    return model
def perform_inference(model, wavs, vocab):
    """Run CTC inference on a single 1-D waveform.

    Args:
        model: XeusForCTC-style callable returning (loss, logits, hidden)
            when given (wavs, labels, wav_lengths).
        wavs: 1-D array-like of audio samples (16 kHz expected by the caller).
        vocab: char -> index mapping used for decoding.

    Returns:
        The list produced by ctc_greedy_decoder (a single string here).
    """
    # Batch of one: record the length, then add the batch dimension directly.
    # The original wrapped the waveform in a Python list, copied it through
    # torch.Tensor([wavs]), and ran pad_sequence on an already-batched tensor
    # (pad_sequence expects a sequence of tensors) — two copies for no effect.
    wav_lengths = torch.LongTensor([len(wavs)])
    batch = torch.as_tensor(wavs, dtype=torch.float32).unsqueeze(0)
    with torch.inference_mode():
        _, logits, _ = model(batch, None, wav_lengths)
    # Decode the most probable token path into text.
    return ctc_greedy_decoder(logits, vocab)
def load_vocab(vocab_path):
    """Read a char -> index vocabulary mapping from a JSON file."""
    return json.loads(Path(vocab_path).read_text())
class Config:
    """Lightweight attribute bag: each key of the given dict becomes an
    attribute on the instance (stand-in for a real model config object)."""

    def __init__(self, config_dict):
        # On a plain (no-slots) class this is equivalent to setattr per key.
        self.__dict__.update(config_dict)
def read_and_resample_wav(wav_path, target_sr=16_000):
    """Load an audio file and resample it to `target_sr` Hz if needed.

    Returns:
        (samples, sample_rate) — sample_rate is always `target_sr` on return.
    """
    samples, native_sr = sf.read(wav_path)
    # Already at the target rate: nothing to do.
    if native_sr == target_sr:
        return samples, native_sr
    # Fourier-domain resampling to the rate-equivalent number of samples.
    new_length = int(len(samples) * target_sr / native_sr)
    return resample(samples, new_length), target_sr
# Example usage:
def main(args):
    """Entry point: resolve checkpoint and vocab paths, build the model,
    and print the transcription of a single audio file."""
    ckpt_path = Path(args.ckpt_path)
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint path '{ckpt_path}' does not exist.")
        return

    # vocab.json is expected to sit next to the checkpoint directory.
    parent_dir = ckpt_path.parent
    vocab_path = parent_dir / "vocab.json"
    if not vocab_path.is_file():
        print(f"vocab.json not found in '{parent_dir}'.")
        return
    vocab_dict = load_vocab(vocab_path)

    # NOTE(review): these values are hard-coded; presumably they must match
    # the checkpoint's training configuration — confirm before reuse.
    config = Config(
        {
            "vocab_size": len(vocab_dict),
            "pad_token_id": vocab_dict["[PAD]"],
            "pretrained_model_path": "./XEUS/model/xeus_checkpoint.pth",
            "final_dropout": 0.1,
            "hidden_size": 1024,
        }
    )

    model = load_model(config, args.ckpt_path)
    audio, _ = read_and_resample_wav(args.audio, target_sr=16_000)
    print(perform_inference(model, audio, vocab_dict))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Inference script for XUES model")
parser.add_argument(
"--ckpt_path", type=str, required=True, help="Path to the checkpoint file"
)
parser.add_argument(
"--audio", type=str, required=True, help="Path to the audio fle"
)
main(parser.parse_args())