Commit 638376b

Show how to decode in python with kaldi pybind and wrap warp-ctc (#3834)
* add pybind wrapper for reading int32 vector and some IO util for reading vec/mat.
* add pybind wrapper for LatticeArc.
* refactor OpenFST pybind to use template.
* [pybind] add pybind wrapper for Lattice and CompactLattice.
* [pybind] add pybind wrapper for Lattice/CompactLattice table I/O.
* [pybind] add pybind wrapper for LatticeFasterDecoderConfig.
* [pybind] add pybind wrapper for HmmTopology and TransitionModel.
* [pybind] add pybind wrapper for DecodableInterface.
* [pybind] add pybind wrapper for LatticeFasterDecoder.
* [pybind] add pybind wrapper for DecodeUtteranceLatticeFaster.
* [pybind] add wrapper for ParseOptions.
* [pybind] support ParseOptions::Register.
* [pybind] show how to decode in Python with kaldi pybind. Now you can perform decoding in Python with kaldi pybind using LatticeFasterDecoder.
* fix CI build error for pybind.
* [pybind] add pybind wrapper for warp-ctc. Note that there is no dependency on any other frameworks, e.g., PyTorch.
1 parent 4981f62 commit 638376b

93 files changed (+4679, -513 lines)

egs/aishell/s10/chain/common.py

Lines changed: 2 additions & 10 deletions
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
 # Apache 2.0

 from datetime import datetime
@@ -81,16 +81,8 @@ def save_training_info(filename, model_path, current_epoch, learning_rate, objf,
     logging.info('write training info to {}'.format(filename))


-def read_mat(filename):
-    ki = kaldi.Input(filename)
-    m = kaldi.FloatMatrix()
-    m.Read(ki.Stream(), binary=True, add=False)
-    ki.Close()
-    return m.numpy()
-
-
 def load_lda_mat(lda_mat_filename):
-    lda_mat = read_mat(lda_mat_filename)
+    lda_mat = kaldi.read_mat(lda_mat_filename).numpy()
     # y = Ax + b,
     # lda contains [A, b], x is feature
     # A.rows() == b.rows()
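
The comment in `load_lda_mat` describes the LDA matrix as an affine transform stored as `[A, b]`; since `A.rows() == b.rows()`, the bias `b` appears to be appended as the last column. A minimal pure-Python sketch of splitting such a matrix and applying `y = Ax + b` follows. The helper names `split_affine` and `apply_affine` are hypothetical and are not part of this commit.

```python
def split_affine(mat):
    """Split a matrix stored as [A, b] into (A, b); b is the last column."""
    A = [row[:-1] for row in mat]
    b = [row[-1] for row in mat]
    return A, b


def apply_affine(A, b, x):
    """Return y = Ax + b for a single feature vector x."""
    if any(len(row) != len(x) for row in A):
        raise ValueError('feature dimension does not match A')
    return [sum(a * v for a, v in zip(row, x)) + bias
            for row, bias in zip(A, b)]


# Toy 2x3 matrix: A is 2x2 and b has 2 entries.
lda = [[1.0, 0.0, 0.5],
       [0.0, 2.0, -1.0]]
A, b = split_affine(lda)
y = apply_affine(A, b, [3.0, 4.0])
```

In the real script the matrix comes from `kaldi.read_mat(...).numpy()`, so the same split would normally be done with array slicing rather than list comprehensions.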

egs/aishell/s10/chain/egs_dataset.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
 # Apache 2.0

 import glob
@@ -39,8 +39,11 @@ def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):

 def read_nnet_chain_example(rxfilename):
     eg = nnet3.NnetChainExample()
-    ki = kaldi.Input(rxfilename=rxfilename)
-    eg.Read(ki.Stream(), True)
+    ki = kaldi.Input()
+    is_opened, is_binary = ki.Open(rxfilename, read_header=True)
+    if not is_opened:
+        raise Exception('Failed to open {}'.format(rxfilename))
+    eg.Read(ki.Stream(), is_binary)
     ki.Close()
     return eg
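
The new code asks `ki.Open` to read the header and report whether the file is binary, instead of hard-coding `True`. In Kaldi's file format, binary data starts with the two bytes `\0B`, and text data has no such marker. A standalone sketch of that check follows; `peek_is_binary` is a hypothetical helper written for illustration and is not part of kaldi pybind.

```python
import io


def peek_is_binary(stream):
    """Return True if the stream starts with the two-byte Kaldi binary
    marker (a NUL byte followed by 'B').

    The marker is consumed when present; otherwise the stream is rewound
    so that a text-mode reader sees the data from the beginning.
    """
    start = stream.tell()
    marker = stream.read(2)
    if marker == b'\x00B':
        return True
    stream.seek(start)
    return False


# Made-up payloads, used only to exercise the two branches.
binary_stream = io.BytesIO(b'\x00BFM ')
text_stream = io.BytesIO(b' [ 1 2 3 ]')
binary_flag = peek_is_binary(binary_stream)
text_flag = peek_is_binary(text_stream)
```

Passing the detected flag on to `Read`, as the diff does, lets the same function load both text and binary archives.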

src/hmm/hmm-topology.h

Lines changed: 11 additions & 20 deletions
@@ -43,26 +43,17 @@ namespace kaldi {
 // The Topology object can have multiple <TopologyEntry> blocks.
 // This is useful if there are multiple types of topology in the system.

-<Topology>
-<TopologyEntry>
-<ForPhones> 1 2 3 4 5 6 7 8 </ForPhones>
-<State> 0 <PdfClass> 0
-<Transition> 0 0.5
-<Transition> 1 0.5
-</State>
-<State> 1 <PdfClass> 1
-<Transition> 1 0.5
-<Transition> 2 0.5
-</State>
-<State> 2 <PdfClass> 2
-<Transition> 2 0.5
-<Transition> 3 0.5
-<Final> 0.5
-</State>
-<State> 3
-</State>
-</TopologyEntry>
-</Topology>
+<Topology>
+<TopologyEntry>
+<ForPhones>
+1 2 3 4 5 6 7 8
+</ForPhones>
+<State> 0 <PdfClass> 0 <Transition> 0 0.5 <Transition> 1 0.5 </State>
+<State> 1 <PdfClass> 1 <Transition> 1 0.5 <Transition> 2 0.5 </State>
+<State> 2 <PdfClass> 2 <Transition> 2 0.5 <Transition> 3 0.5 </State>
+<State> 3 </State>
+</TopologyEntry>
+</Topology>
 */

 // kNoPdf is used where pdf_class or pdf would be used, to indicate,
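
The reformatted comment puts each `<State>` on a single line. A rough sketch of reading such lines into a per-state transition table, and checking that the outgoing probabilities of each state sum to one, is given here. The regular expressions and the `parse_states` helper are illustrative assumptions; they are not Kaldi's `HmmTopology::Read`.

```python
import re

TOPO = """
<State> 0 <PdfClass> 0 <Transition> 0 0.5 <Transition> 1 0.5 </State>
<State> 1 <PdfClass> 1 <Transition> 1 0.5 <Transition> 2 0.5 </State>
<State> 2 <PdfClass> 2 <Transition> 2 0.5 <Transition> 3 0.5 </State>
<State> 3 </State>
"""

STATE_RE = re.compile(r'<State>\s+(\d+)(.*?)</State>', re.S)
TRANS_RE = re.compile(r'<Transition>\s+(\d+)\s+([0-9.]+)')


def parse_states(text):
    """Map each state index to a list of (destination, probability)."""
    states = {}
    for match in STATE_RE.finditer(text):
        index = int(match.group(1))
        states[index] = [(int(dst), float(prob))
                         for dst, prob in TRANS_RE.findall(match.group(2))]
    return states


states = parse_states(TOPO)
# In this example only state 3 has no outgoing transitions.
final_states = [s for s, arcs in states.items() if not arcs]
for s, arcs in states.items():
    if arcs:
        assert abs(sum(p for _, p in arcs) - 1.0) < 1e-6
```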

src/pybind/Makefile

Lines changed: 35 additions & 2 deletions
@@ -32,6 +32,7 @@ endif

 CXXFLAGS += -O3 $(LTOFLAG) -I.
 LDFLAGS += $(LTOFLAG)
+LDFLAGS += -fuse-ld=gold

 ifeq ($(shell uname),Darwin)
 LDFLAGS += -undefined dynamic_lookup
@@ -43,10 +44,15 @@ chain/chain_den_graph_pybind.cc \
 chain/chain_pybind.cc \
 chain/chain_supervision_pybind.cc \
 chain/chain_training_pybind.cc \
+ctc/ctc_pybind.cc \
 cudamatrix/cu_device_pybind.cc \
 cudamatrix/cu_matrix_pybind.cc \
 cudamatrix/cu_vector_pybind.cc \
 cudamatrix/cudamatrix_pybind.cc \
+decoder/decodable_matrix_pybind.cc \
+decoder/decoder_pybind.cc \
+decoder/decoder_wrappers_pybind.cc \
+decoder/lattice_faster_decoder_pybind.cc \
 dlpack/dlpack_deleter.cc \
 dlpack/dlpack_pybind.cc \
 dlpack/dlpack_submatrix.cc \
@@ -61,8 +67,20 @@ fst/fst_pybind.cc \
 fst/symbol_table_pybind.cc \
 fst/vector_fst_pybind.cc \
 fst/weight_pybind.cc \
+fstext/fstext_pybind.cc \
 fstext/kaldi_fst_io_pybind.cc \
+fstext/lattice_weight_pybind.cc \
+hmm/hmm_pybind.cc \
+hmm/hmm_topology_pybind.cc \
+hmm/transition_model_pybind.cc \
+itf/context_dep_itf_pybind.cc \
+itf/decodable_itf_pybind.cc \
+itf/itf_pybind.cc \
+itf/options_itf_pybind.cc \
 kaldi_pybind.cc \
+lat/determinize_lattice_pruned_pybind.cc \
+lat/kaldi_lattice_pybind.cc \
+lat/lat_pybind.cc \
 matrix/compressed_matrix_pybind.cc \
 matrix/kaldi_matrix_pybind.cc \
 matrix/kaldi_vector_pybind.cc \
@@ -74,7 +92,9 @@ nnet3/nnet_chain_example_pybind.cc \
 nnet3/nnet_common_pybind.cc \
 nnet3/nnet_example_pybind.cc \
 tests/test_dlpack_subvector.cc \
+util/kaldi_holder_pybind.cc \
 util/kaldi_io_pybind.cc \
+util/parse_options_pybind.cc \
 util/table_types_pybind.cc \
 util/util_pybind.cc

@@ -87,8 +107,11 @@ ADDLIBS := \
 ../base/kaldi-base.a \
 ../chain/kaldi-chain.a \
 ../cudamatrix/kaldi-cudamatrix.a \
+../decoder/kaldi-decoder.a \
 ../feat/kaldi-feat.a \
 ../fstext/kaldi-fstext.a \
+../hmm/kaldi-hmm.a \
+../lat/kaldi-lat.a \
 ../matrix/kaldi-matrix.a \
 ../nnet3/kaldi-nnet3.a \
 ../util/kaldi-util.a
@@ -110,17 +133,27 @@ endif
 TEST_DIRS := \
 chain \
 cudamatrix \
+decoder \
 dlpack \
 feat \
 fst \
+fstext \
+hmm \
+lat \
 matrix \
 nnet3 \
-tests
+tests \
+util

 .PHONY: all clean test

 all: $(LIBFILE)

+include ctc/ctc.mk
+
+ctc/%.o: ctc/%.cc $(LIB_WARP_CTC)
+	$(CXX) -c $(CXXFLAGS) -o $@ $<
+
 %.o: %.cc
 	$(CXX) -c $(CXXFLAGS) -o $@ $<

@@ -135,7 +168,7 @@ clean:
 	-rm -f .depend.mk

 test: all
-	for d in $(TEST_DIRS); do make -C $$d test; done
+	for d in $(TEST_DIRS); do make -C $$d test || exit 1; done

 # valgrind-python.supp is from http://svn.python.org/projects/python/trunk/Misc/valgrind-python.supp
 # since we do not compile Python from source, we follow the comment in valgrind-python.supp

src/pybind/ctc/Makefile

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+
+test:
+	python3 ./ctc_pybind_test.py
+	python3 ./ctc_pybind_test_gpu.py

src/pybind/ctc/README.md

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+
+This directory provides wrapper for
+warp-ctc from https://github.com/baidu-research/warp-ctc.
+
+warp-ctc uses the same license as Kaldi, Apache License 2.0.
+
+Although warp-ctc has not been updated for a long time, it
+is still widely used. For example, espnet is still using
+it (https://github.com/espnet/espnet/issues/1434).
+
+When it comes to PyTorch, we may switch to its built-in
+`torch.nn.CTCLoss` (https://pytorch.org/docs/stable/nn.html#torch.nn.CTCLoss)
+if we find it faster than warp-ctc. This needs some benchmarks.
+
+Note that this wrapper has no dependencies on PyTorch.
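
For readers unfamiliar with what the wrapped library computes: the CTC loss is the negative log-likelihood of a label sequence, summed over all frame-level alignments that collapse to it. The sketch below shows that quantity with the standard forward recursion in plain Python. It is an illustration only: the function name `ctc_neg_log_likelihood` is hypothetical, it does not use or mirror the warp-ctc API, and it works on already-normalized probabilities rather than in the log domain.

```python
import math


def ctc_neg_log_likelihood(probs, labels, blank=0):
    """Negative log-likelihood of `labels` under CTC.

    probs[t][k] is the probability of symbol k at frame t (rows sum to 1).
    Uses the forward recursion over the blank-extended label sequence.
    Plain probabilities are used, so this suits short inputs only.
    """
    # Interleave blanks: [l1, l2] becomes [blank, l1, blank, l2, blank].
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    S = len(ext)

    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, len(probs)):
        prev = alpha
        alpha = [0.0] * S
        for s in range(S):
            total = prev[s]
            if s >= 1:
                total += prev[s - 1]
            # Skipping over a blank is allowed only between different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += prev[s - 2]
            alpha[s] = total * probs[t][ext[s]]

    # Valid alignments end in the last label or the trailing blank.
    likelihood = alpha[-1] + (alpha[-2] if S > 1 else 0.0)
    return -math.log(likelihood)


# Two frames, two symbols (blank=0, label=1), uniform probabilities.
uniform = [[0.5, 0.5], [0.5, 0.5]]
loss = ctc_neg_log_likelihood(uniform, [1])
```

With two frames and the single label `1`, three alignments are valid (`1 1`, `blank 1`, `1 blank`), each with probability 0.25, so the likelihood is 0.75.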

src/pybind/ctc/ctc.mk

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+
+WGET ?= wget
+
+# This commit is the latest commit available as of 2019.01.15 .
+# Update it if needed.
+COMMIT_ID := bc29dcfff07ced1c7a19a4ecee48e5ad583cef8e
+
+WARP_CTC_FILENAME := ctc/warp-ctc.tar.gz
+LIB_WARP_CTC := ctc/warp-ctc/build/libwarpctc.so
+
+LDFLAGS += -Wl,-rpath=$(CURDIR)/ctc/warp-ctc/build
+EXTRA_LDLIBS += $(LIB_WARP_CTC)
+
+WITH_OMP := ON
+
+ifdef CI_TARGETS
+WITH_OMP := OFF
+endif
+
+$(LIB_WARP_CTC): $(WARP_CTC_FILENAME)
+	cd ctc/warp-ctc && \
+	sed -i 's/--std=c++11/-std=c++11/g' CMakeLists.txt && \
+	mkdir -p build && \
+	cd build && \
+	cmake -DBUILD_TESTS=OFF \
+	  -DWITH_OMP=$(WITH_OMP) \
+	  -DBUILD_SHARED=ON .. && \
+	make -j2
+
+$(WARP_CTC_FILENAME):
+	cd ctc && \
+	$(WGET) -O warp-ctc.tar.gz \
+	  --timeout=10 --tries=3 \
+	  https://github.com/baidu-research/warp-ctc/archive/$(COMMIT_ID).tar.gz && \
+	tar xf warp-ctc.tar.gz && \
+	ln -sf warp-ctc-$(COMMIT_ID) warp-ctc
