Skip to content

Commit 4609ea1

Browse files
medbarAnton Mitrofanov
andauthored
[egs] LT-LM recipe for librispeech (#4590)
Co-authored-by: Anton Mitrofanov <[email protected]>
1 parent 7460d99 commit 4609ea1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+4926
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# LT-LM: a novel non-autoregressive language model for single-shot lattice rescoring
2+
[Paper](https://arxiv.org/pdf/2104.02526.pdf)
3+
4+
## Setup:
5+
`cd fairseq_ltlm && setup.sh`
6+
## run:
7+
* Put slurm.conf into conf/
8+
* modify fairseq\_ltlm/recipes/config.sh if needed
9+
* `bash fairseq\_ltlm/recipes/run.sh`
10+
## evaluate:
11+
For evaluation, you can
12+
run fairseq\_ltlm/recipes/run\_5\_eval.sh (see run.sh) or use fairseq\_ltlm/ltlm/eval.py directly.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Build file for the fake-acoustic-model lattice generator.
# Relative path from this directory to the Kaldi root (the directory that contains src/).
kaldi=../../../../..
2+
# Declared before the includes so that `all` is the first (default) goal.
# NOTE(review): its prerequisites are presumably supplied by default_rules.mk — confirm.
all:
3+
4+
# Kaldi build configuration.
include $(kaldi)/src/kaldi.mk
5+
6+
# Suppress signed/unsigned comparison warnings.
EXTRA_CXXFLAGS += -Wno-sign-compare
7+
# Add the Kaldi src/ directory to the header search path (the source uses
# includes such as "base/kaldi-common.h").
EXTRA_CXXFLAGS += -I$(kaldi)/src
8+
# Name of the binary to build.
# NOTE(review): presumably consumed by default_rules.mk — confirm.
BINFILES = latgen-faster-mapped-fake-am
9+
10+
# No extra object files.
OBJFILES =
11+
12+
# No unit-test binaries.
TESTFILES =
13+
14+
# Kaldi static libraries to link against. The order is kept exactly as written;
# static-library link order can matter, so do not reorder without checking.
ADDLIBS = $(kaldi)/src/decoder/kaldi-decoder.a $(kaldi)/src/lat/kaldi-lat.a $(kaldi)/src/lm/kaldi-lm.a \
15+
$(kaldi)/src/fstext/kaldi-fstext.a $(kaldi)/src/hmm/kaldi-hmm.a \
16+
$(kaldi)/src/transform/kaldi-transform.a $(kaldi)/src/gmm/kaldi-gmm.a \
17+
$(kaldi)/src/tree/kaldi-tree.a $(kaldi)/src/util/kaldi-util.a $(kaldi)/src/matrix/kaldi-matrix.a \
18+
$(kaldi)/src/base/kaldi-base.a
19+
20+
21+
# Shared Kaldi make rules; included last, after the variables above are set.
include $(kaldi)/src/makefiles/default_rules.mk
22+
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
2+
// Copyright (c) 2021, Speech Technology Center Ltd. All rights reserved.
3+
// Anton Mitrofanov
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
// This is latgen-faster-mapped adapted for fake lattice generation.
18+
19+
#include <chrono>
20+
21+
#include "base/kaldi-common.h"
22+
#include "base/timer.h"
23+
#include "decoder/decodable-matrix.h"
24+
#include "decoder/decoder-wrappers.h"
25+
#include "fstext/fstext-lib.h"
26+
#include "hmm/transition-model.h"
27+
#include "tree/context-dep.h"
28+
#include "util/common-utils.h"
29+
30+
using namespace kaldi;
31+
typedef kaldi::int32 int32;
32+
using fst::Fst;
33+
using fst::StdArc;
34+
using fst::SymbolTable;
35+
36+
int main(int argc, char *argv[]) {
37+
try {
38+
const char *usage =
39+
"Generate lattices, reading emulating am as matrices\n"
40+
" (model is needed only for the integer mappings in its "
41+
"transition-model)\n"
42+
"Usage: latgen-faster-mapped-fake-am [options] trans-model-in fst-in "
43+
"fam-rspecifier ali_rspecifier"
44+
" lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
45+
ParseOptions po(usage);
46+
Timer timer;
47+
bool allow_partial = false;
48+
BaseFloat acoustic_scale = 0.1;
49+
LatticeFasterDecoderConfig config;
50+
51+
std::string word_syms_filename;
52+
config.Register(&po);
53+
po.Register("acoustic-scale", &acoustic_scale,
54+
"Scaling factor for acoustic likelihoods");
55+
56+
po.Register("word-symbol-table", &word_syms_filename,
57+
"Symbol table for words [for debug output]");
58+
po.Register("allow-partial", &allow_partial,
59+
"If true, produce output even if end state was not reached.");
60+
61+
po.Read(argc, argv);
62+
63+
if (po.NumArgs() < 5 || po.NumArgs() > 7) {
64+
po.PrintUsage();
65+
exit(1);
66+
}
67+
68+
std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2),
69+
fam_rspecifier = po.GetArg(3), ali_rspecifier = po.GetArg(4),
70+
lattice_wspecifier = po.GetArg(5),
71+
words_wspecifier = po.GetOptArg(6),
72+
alignment_wspecifier = po.GetOptArg(7);
73+
74+
TransitionModel trans_model;
75+
ReadKaldiObject(model_in_filename, &trans_model);
76+
77+
bool determinize = config.determinize_lattice;
78+
CompactLatticeWriter compact_lattice_writer;
79+
LatticeWriter lattice_writer;
80+
if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier)
81+
: lattice_writer.Open(lattice_wspecifier)))
82+
KALDI_ERR << "Could not open table for writing lattices: "
83+
<< lattice_wspecifier;
84+
85+
Int32VectorWriter words_writer(words_wspecifier);
86+
87+
Int32VectorWriter alignment_writer(alignment_wspecifier);
88+
89+
fst::SymbolTable *word_syms = NULL;
90+
if (word_syms_filename != "")
91+
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
92+
KALDI_ERR << "Could not read symbol table from file "
93+
<< word_syms_filename;
94+
95+
double tot_like = 0.0;
96+
kaldi::int64 frame_count = 0;
97+
int num_success = 0, num_fail = 0;
98+
99+
// Reading Fake acoustic model form ark file
100+
KALDI_LOG << "Loading Fake Acoustic Model";
101+
SequentialBaseFloatMatrixReader fam_model_read(fam_rspecifier);
102+
std::string fam_model_key = fam_model_read.Key();
103+
Matrix<BaseFloat> fam_model(fam_model_read.Value());
104+
KALDI_LOG << "Apply log.";
105+
fam_model.ApplyLog();
106+
107+
if (fam_model_key != "fam_model") {
108+
KALDI_ERR << fam_rspecifier << " - Wrong fam_model.";
109+
po.PrintUsage();
110+
exit(1);
111+
}
112+
KALDI_LOG << "Fake Acoustic is loaded. Shape is (" << fam_model.NumRows()
113+
<< ", " << fam_model.NumCols() << ")";
114+
115+
SequentialInt32VectorReader ali_reader(ali_rspecifier);
116+
if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) {
117+
// SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier);
118+
// Input FST is just one FST, not a table of FSTs.
119+
Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str);
120+
timer.Reset();
121+
122+
{
123+
LatticeFasterDecoder decoder(*decode_fst, config);
124+
for (; !ali_reader.Done(); ali_reader.Next()) {
125+
std::string utt = ali_reader.Key();
126+
std::vector<int32> ali(ali_reader.Value());
127+
KALDI_LOG << "Process " << utt << ". " << ali.size() << " frames";
128+
ali_reader.FreeCurrent();
129+
if (ali.size() == 0) {
130+
KALDI_WARN << "Zero-length utterance: " << utt;
131+
num_fail++;
132+
continue;
133+
}
134+
// Inference fake AM
135+
kaldi::Matrix<BaseFloat> loglikes(ali.size(), fam_model.NumRows());
136+
loglikes.SetZero();
137+
for (int i = 0; i < ali.size(); i++) {
138+
int32 pdf_id = ali[i];
139+
loglikes.CopyRowFromVec(fam_model.Row(pdf_id), i);
140+
SubVector<BaseFloat> row(loglikes, i);
141+
}
142+
// end
143+
DecodableMatrixScaledMapped decodable(trans_model, loglikes,
144+
acoustic_scale);
145+
146+
double like;
147+
if (DecodeUtteranceLatticeFaster(
148+
decoder, decodable, trans_model, word_syms, utt,
149+
acoustic_scale, determinize, allow_partial, &alignment_writer,
150+
&words_writer, &compact_lattice_writer, &lattice_writer,
151+
&like)) {
152+
tot_like += like;
153+
frame_count += loglikes.NumRows();
154+
num_success++;
155+
} else
156+
num_fail++;
157+
}
158+
}
159+
delete decode_fst; // delete this only after decoder goes out of scope.
160+
} else { // We have different FSTs for different utterances.
161+
KALDI_LOG << "FSTs not implemented yet.";
162+
exit(1);
163+
// SequentialTableReader<fst::VectorFstHolder>
164+
// fst_reader(fst_in_str); RandomAccessBaseFloatMatrixReader
165+
// loglike_reader(feature_rspecifier); for (; !fst_reader.Done();
166+
// fst_reader.Next()) {
167+
// std::string utt = fst_reader.Key();
168+
// if (!loglike_reader.HasKey(utt)) {
169+
// KALDI_WARN << "Not decoding utterance " << utt
170+
// << " because no loglikes available.";
171+
// num_fail++;
172+
// continue;
173+
// }
174+
// const Matrix<BaseFloat> &loglikes = loglike_reader.Value(utt);
175+
// if (loglikes.NumRows() == 0) {
176+
// KALDI_WARN << "Zero-length utterance: " << utt;
177+
// num_fail++;
178+
// continue;
179+
// }
180+
// LatticeFasterDecoder decoder(fst_reader.Value(), config);
181+
// DecodableMatrixScaledMapped decodable(trans_model, loglikes,
182+
// acoustic_scale); double like; if (DecodeUtteranceLatticeFaster(
183+
// decoder, decodable, trans_model, word_syms, utt,
184+
// acoustic_scale, determinize, allow_partial,
185+
// &alignment_writer, &words_writer,
186+
// &compact_lattice_writer, &lattice_writer, &like)) {
187+
// tot_like += like;
188+
// frame_count += loglikes.NumRows();
189+
// num_success++;
190+
// } else num_fail++;
191+
// }
192+
}
193+
194+
double elapsed = timer.Elapsed();
195+
KALDI_LOG << "Time taken " << elapsed
196+
<< "s: real-time factor assuming 100 frames/sec is "
197+
<< (elapsed * 100.0 / frame_count);
198+
KALDI_LOG << "Done " << num_success << " utterances, failed for "
199+
<< num_fail;
200+
KALDI_LOG << "Overall log-likelihood per frame is "
201+
<< (tot_like / frame_count) << " over " << frame_count
202+
<< " frames.";
203+
204+
delete word_syms;
205+
if (num_success != 0)
206+
return 0;
207+
else
208+
return 1;
209+
} catch (const std::exception &e) {
210+
std::cerr << e.what();
211+
return -1;
212+
}
213+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2021 STC-Innovation LTD (Author: Anton Mitrofanov)
2+
import argparse
3+
import logging
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class WordTokenizer:
    """Word <-> integer-id mapping loaded from a ``words.txt`` file.

    Each non-empty line of the file must hold exactly two whitespace-separated
    fields: ``<word> <integer id>``.  The table must map ``<eps>`` to 0 and
    contain ``<s>``, ``</s>`` and an unknown-word symbol (``unk_word``, or its
    lower-cased form as a fallback).

    Attributes:
        word2id: dict mapping word -> id.
        id2word: list mapping id -> word ('' for ids absent from the file).
        unk: the unknown-word symbol actually found in the table.
        real_words_ids: ids of ordinary words (see ``_is_special``).
        disambig_word_ids: ids of special symbols other than ``<s>``/``</s>``.
    """

    # A word containing any of these characters is treated as a special symbol.
    _SPECIAL_CHARS = '<>#![]'

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Register the tokenizer's command-line options on ``parser``."""
        parser.add_argument('--tokenizer_fn', type=str, required=True,
                            help='Tokenizer fname( words.txt)')
        # parser.add_argument('--unk', default='<UNK>', help="Unk word")  # fairseq bug

    @staticmethod
    def build_from_args(args):
        """Build a tokenizer from parsed arguments (reads ``args.tokenizer_fn``)."""
        return WordTokenizer(fname=args.tokenizer_fn)

    @classmethod
    def _is_special(cls, word):
        """Return True if ``word`` is not an ordinary word.

        A word is special when it contains one of ``< > # ! [ ]`` or starts
        or ends with ``-``.
        """
        return (any(ch in word for ch in cls._SPECIAL_CHARS)
                or word.startswith('-')
                or word.endswith('-'))

    def __init__(self, fname, unk_word='<UNK>'):
        """Load the vocabulary from ``fname``.

        Args:
            fname: path to the words.txt file (read as UTF-8).
            unk_word: unknown-word symbol; if absent from the table, its
                lower-cased form is tried.

        Raises:
            ValueError: if the file has no entries, a line does not have
                exactly two fields, or no unknown-word symbol is present.
            AssertionError: if ``<eps>`` is not id 0 or ``<s>``/``</s>`` is
                missing.
        """
        logger.info(f'Loading WordTokenizer {fname}')
        self.word2id = {}
        with open(fname, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                word, idx = fields
                self.word2id[word] = int(idx)
        if not self.word2id:
            raise ValueError(f"no vocabulary entries found in {fname}")

        self.id2word = [''] * (max(self.word2id.values()) + 1)
        for w, i in self.word2id.items():
            self.id2word[i] = w

        self.unk = unk_word
        if self.unk not in self.word2id:
            if self.unk.lower() in self.word2id:
                self.unk = self.unk.lower()
            else:
                # Exceptions must be BaseException instances; raising a plain
                # string would itself fail with a TypeError.
                raise ValueError(f"unk word {unk_word} not in {fname}")

        assert self.word2id.get('<eps>') == 0 and \
            '<s>' in self.word2id and \
            '</s>' in self.word2id, \
            f"{fname}: <eps> must have id 0 and <s>, </s> must be present"

        self.real_words_ids = [i for w, i in self.word2id.items()
                               if not self._is_special(w)]

        logger.info(f'WordTokenizer {fname} loaded. Vocab size {len(self)}.')
        self.disambig_word_ids = [i for w, i in self.word2id.items()
                                  if w not in ('<s>', '</s>') and self._is_special(w)]
        logger.info(f"WordTokenizer Disambig ids: {self.disambig_word_ids}")
        logger.info(f"WordTokenizer Disambig words: {[ self.id2word[i] for i in self.disambig_word_ids]}")

    def encode(self, text, bos=False, eos=False):
        """Convert an iterable of sentences into lists of word ids.

        Args:
            text: iterable of strings; each is split on whitespace.
            bos: prepend the ``<s>`` id to every sentence.
            eos: append the ``</s>`` id to every sentence.

        Returns:
            A list with one list of ids per input sentence; out-of-vocabulary
            words are mapped to the unknown-word id.
        """
        unk_id = self.get_unk_id()
        prefix = [self.get_bos_id()] if bos else []
        suffix = [self.get_eos_id()] if eos else []
        return [prefix + [self.word2id.get(w, unk_id) for w in line.split()] + suffix
                for line in text]

    def decode(self, text_ids):
        """Convert an iterable of id sequences back into lists of words."""
        return [[self.id2word[i] for i in line_ids] for line_ids in text_ids]

    def __len__(self):
        """Size of the id space (largest id + 1)."""
        return len(self.id2word)

    def get_real_words_ids(self):
        """Ids of ordinary (non-special) words, in file order."""
        return self.real_words_ids

    def get_disambig_words_ids(self):
        """Ids of special symbols other than ``<s>``/``</s>``, in file order."""
        return self.disambig_word_ids

    def get_bos_id(self):
        """Id of ``<s>``."""
        return self.word2id["<s>"]

    def get_eos_id(self):
        """Id of ``</s>``."""
        return self.word2id["</s>"]

    def get_unk_id(self):
        """Id of the unknown-word symbol."""
        return self.word2id[self.unk]

    def pad(self):
        """Id of ``<eps>``."""
        return self.word2id['<eps>']

    def print_lat(self, lat, print_word_id=False, p=None):
        """Print one line per arc of ``lat``.

        Each arc is indexed as ``arc[0]`` (word id), ``arc[1]`` and ``arc[2]``;
        the line is ``"<arc[1]> <arc[2]> <word or id>"``, followed by ``p[i]``
        when ``p`` is given.
        """
        for i, arc in enumerate(lat):
            out_str = f"{arc[1]} {arc[2]} {arc[0] if print_word_id else self.id2word[arc[0]]}"
            if p is not None:
                out_str += f" {p[i]}"
            print(out_str)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2021 STC-Innovation LTD (Author: Anton Mitrofanov)
2+
from .criterions.bce_loss import BCECriterion
3+
from .datasets import LatsOracleAlignDataSet
4+
from .models import LTLM
5+
from .tasks.rescoring_task import RescoringTask

0 commit comments

Comments
 (0)