Skip to content

Commit 51a2732

Browse files
naxingyuvimalmanoharferb2015csukuangfjjtrmal
authored andcommitted
Sync pybind11 with master (#3832)
* [scripts] allowed_durations computation in standalone script * [src] Make compute-gop work with missing alignments (#3830) * [pybind] Support to construct CuSubMatrix/CuSubVector from DLPack without GPU support. (#3828) * Update .travis config to build the pybind11 branch * disable lto for travis in pybind * [pybind] add pybind wrapper for int vector reader/writer. (#3833) Co-authored-by: Vimal Manohar <[email protected]> Co-authored-by: ferb2015 <[email protected]> Co-authored-by: Fangjun Kuang <[email protected]> Co-authored-by: Jan "yenda" Trmal <[email protected]>
1 parent 4c87771 commit 51a2732

File tree

5 files changed

+231
-3
lines changed

5 files changed

+231
-3
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ addons:
2929
branches:
3030
only:
3131
- master
32+
- pybind11
3233

3334
before_install:
3435
- cat /proc/sys/kernel/core_pattern
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2017 Hossein Hadian
4+
# 2019 Facebook Inc. (Author: Vimal Manohar)
5+
# Apache 2.0
6+
7+
8+
""" This script generates a set of allowed lengths of utterances
9+
spaced by a factor (like 10%). This is useful for generating
10+
fixed-length chunks for chain training.
11+
"""
12+
13+
import argparse
14+
import os
15+
import sys
16+
import copy
17+
import math
18+
import logging
19+
20+
sys.path.insert(0, 'steps')
21+
import libs.common as common_lib
22+
23+
logger = logging.getLogger('libs')
24+
logger.setLevel(logging.INFO)
25+
handler = logging.StreamHandler()
26+
handler.setLevel(logging.INFO)
27+
formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
28+
"%(funcName)s - %(levelname)s ] %(message)s")
29+
handler.setFormatter(formatter)
30+
logger.addHandler(handler)
31+
32+
def get_args():
33+
parser = argparse.ArgumentParser(description="""
34+
This script creates a list of allowed durations of utterances for flatstart
35+
LF-MMI training corresponding to input data directory 'data_dir' and writes
36+
it in two files in output directory 'dir':
37+
1) allowed_durs.txt -- durations are in seconds
38+
2) allowed_lengths.txt -- lengths are in number of frames
39+
40+
Both the allowed_durs.txt and allowed_lengths.txt are formatted to
41+
have one entry on each line. Examples are as follows:
42+
43+
$ echo data/train/allowed_lengths.txt
44+
414
45+
435
46+
468
47+
48+
$ echo data/train/allowed_durs.txt
49+
4.16
50+
4.37
51+
4.70
52+
53+
These files can then be used by a downstream script to perturb the
54+
utterances to these lengths.
55+
A perturbed data directory (created by a downstream script
56+
similar to utils/data/perturb_speed_to_allowed_lengths.py)
57+
that only contains utterances of these allowed durations,
58+
along with the corresponding allowed_lengths.txt are
59+
consumed by the e2e chain egs preparation script.
60+
See steps/nnet3/chain/e2e/get_egs_e2e.sh for how these are used.
61+
62+
See also:
63+
* egs/cifar/v1/image/get_allowed_lengths.py -- a similar script for OCR datasets
64+
* utils/data/perturb_speed_to_allowed_lengths.py --
65+
creates the allowed_lengths.txt AND perturbs the data directory
66+
""")
67+
parser.add_argument('factor', type=float, default=12,
68+
help='Spacing (in percentage) between allowed lengths. '
69+
'Can be 0, which means all seen lengths that are a multiple of '
70+
'frame_subsampling_factor will be allowed.')
71+
parser.add_argument('data_dir', type=str, help='path to data dir. Assumes that '
72+
'it contains the utt2dur file.')
73+
parser.add_argument('dir', type=str, help='We write the output files '
74+
'allowed_lengths.txt and allowed_durs.txt to this directory.')
75+
parser.add_argument('--coverage-factor', type=float, default=0.05,
76+
help="""Percentage of durations not covered from each
77+
side of duration histogram.""")
78+
parser.add_argument('--frame-shift', type=int, default=10,
79+
help="""Frame shift in milliseconds.""")
80+
parser.add_argument('--frame-length', type=int, default=25,
81+
help="""Frame length in milliseconds.""")
82+
parser.add_argument('--frame-subsampling-factor', type=int, default=3,
83+
help="""Chain frame subsampling factor.
84+
See steps/nnet3/chain/train.py""")
85+
args = parser.parse_args()
86+
return args
87+
88+
89+
def read_kaldi_mapfile(path):
90+
""" Read any Kaldi mapping file - like text, .scp files, etc.
91+
"""
92+
93+
m = {}
94+
with open(path, 'r', encoding='latin-1') as f:
95+
for line in f:
96+
line = line.strip(" \t\r\n")
97+
sp_pos = line.find(' ')
98+
key = line[:sp_pos]
99+
val = line[sp_pos+1:]
100+
m[key] = val
101+
return m
102+
103+
104+
def find_duration_range(utt2dur, coverage_factor):
105+
"""Given a list of utterance durations, find the start and end duration to cover
106+
107+
If we try to cover
108+
all durations which occur in the training set, the number of
109+
allowed lengths could become very large.
110+
111+
Returns
112+
-------
113+
start_dur: float
114+
end_dur: float
115+
"""
116+
durs = [float(val) for key, val in utt2dur.items()]
117+
durs.sort()
118+
to_ignore_dur = 0
119+
tot_dur = sum(durs)
120+
for d in durs:
121+
to_ignore_dur += d
122+
if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
123+
start_dur = d
124+
break
125+
to_ignore_dur = 0
126+
for d in reversed(durs):
127+
to_ignore_dur += d
128+
if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
129+
end_dur = d
130+
break
131+
if start_dur < 0.3:
132+
start_dur = 0.3 # a hard limit to avoid too many allowed lengths --not critical
133+
return start_dur, end_dur
134+
135+
136+
def get_allowed_durations(start_dur, end_dur, args):
137+
"""Given the start and end duration, find a set of
138+
allowed durations spaced by args.factor%. Also write
139+
out the list of allowed durations and the corresponding
140+
allowed lengths (in frames) on disk.
141+
142+
Returns
143+
-------
144+
allowed_durations: list of allowed durations (in seconds)
145+
"""
146+
147+
allowed_durations = []
148+
d = start_dur
149+
with open(os.path.join(args.dir, 'allowed_durs.txt'), 'w', encoding='latin-1') as durs_fp, \
150+
open(os.path.join(args.dir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as lengths_fp:
151+
while d < end_dur:
152+
length = int(d * 1000 - args.frame_length) / args.frame_shift + 1
153+
if length % args.frame_subsampling_factor != 0:
154+
length = (args.frame_subsampling_factor *
155+
(length // args.frame_subsampling_factor))
156+
d = (args.frame_shift * (length - 1.0)
157+
+ args.frame_length + args.frame_shift / 2) / 1000.0
158+
allowed_durations.append(d)
159+
durs_fp.write("{}\n".format(d))
160+
lengths_fp.write("{}\n".format(int(length)))
161+
d *= args.factor
162+
return allowed_durations
163+
164+
165+
def get_trivial_allowed_durations(utt2dur, args):
166+
lengths = list(set(
167+
[int(float(d) * 1000 - args.frame_length) / args.frame_shift + 1
168+
for key, d in utt2dur.items()]
169+
))
170+
lengths.sort()
171+
172+
allowed_durations = []
173+
with open(os.path.join(args.dir, 'allowed_durs.txt'), 'w', encoding='latin-1') as durs_fp, \
174+
open(os.path.join(args.dir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as lengths_fp:
175+
for length in lengths:
176+
if length % args.frame_subsampling_factor != 0:
177+
length = (args.frame_subsampling_factor *
178+
(length // args.frame_subsampling_factor))
179+
d = (args.frame_shift * (length - 1.0)
180+
+ args.frame_length + args.frame_shift / 2) / 1000.0
181+
allowed_durations.append(d)
182+
durs_fp.write("{}\n".format(d))
183+
lengths_fp.write("{}\n".format(length))
184+
185+
assert len(allowed_durations) > 0
186+
start_dur = allowed_durations[0]
187+
end_dur = allowed_durations[-1]
188+
189+
logger.info("Durations in the range [{},{}] will be covered."
190+
"".format(start_dur, end_dur))
191+
logger.info("There will be {} unique allowed lengths "
192+
"for the utterances.".format(len(allowed_durations)))
193+
194+
return allowed_durations
195+
196+
197+
def main():
198+
args = get_args()
199+
utt2dur = read_kaldi_mapfile(os.path.join(args.data_dir, 'utt2dur'))
200+
201+
if args.factor == 0.0:
202+
get_trivial_allowed_durations(utt2dur, args)
203+
return
204+
205+
args.factor = 1.0 + args.factor / 100.0
206+
207+
start_dur, end_dur = find_duration_range(utt2dur, args.coverage_factor)
208+
logger.info("Durations in the range [{},{}] will be covered. "
209+
"Coverage rate: {}%".format(start_dur, end_dur,
210+
100.0 - args.coverage_factor * 2))
211+
logger.info("There will be {} unique allowed lengths "
212+
"for the utterances.".format(int(math.log(end_dur / start_dur)/
213+
math.log(args.factor))))
214+
215+
get_allowed_durations(start_dur, end_dur, args)
216+
217+
218+
if __name__ == '__main__':
219+
main()

src/bin/compute-gop.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ int main(int argc, char *argv[]) {
155155
int32 num_done = 0;
156156
for (; !prob_reader.Done(); prob_reader.Next()) {
157157
std::string key = prob_reader.Key();
158+
if (!alignment_reader.HasKey(key)) {
159+
KALDI_WARN << "No alignment for utterance " << key;
160+
continue;
161+
}
158162
auto alignment = alignment_reader.Value(key);
159163
Matrix<BaseFloat> &probs = prob_reader.Value();
160164
if (log_applied) probs.ApplyExp();

src/pybind/Makefile

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@ LIBFILE_EXTENSION := $(PYBIND_EXTENSION)
2626

2727
# pybind11 is heavily templated and generates code that is bloated before optimization.
2828
# -flto is link time optimization which apparently is important.
29-
CXXFLAGS += -O3 -flto -I.
30-
LDFLAGS += -flto
29+
ifndef CI_TARGETS
30+
LTOFLAG = -flto
31+
endif
32+
33+
CXXFLAGS += -O3 $(LTOFLAG) -I.
34+
LDFLAGS += $(LTOFLAG)
3135

3236
ifeq ($(shell uname),Darwin)
3337
LDFLAGS += -undefined dynamic_lookup

tools/extras/travis_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ CCC=$(mtoken CXX "$CXX")
6868
echo "Building tools..." [Time: $(date)]
6969
runvx cd tools
7070
runvx make -j$MAXPAR openfst "$CCC" CXXFLAGS="$CF" \
71-
OPENFST_CONFIGURE="--disable-static --enable-shared --disable-bin --disable-dependency-tracking"
71+
OPENFST_CONFIGURE="--disable-static --enable-shared --disable-dependency-tracking"
7272
runvx make -j$MAXPAR cub "$CCC" CXXFLAGS="$CF"
7373
cd ..
7474

0 commit comments

Comments
 (0)