2020// limitations under the License.
2121
2222#include " nnet3/nnet-chain-training2.h"
23+
2324#include " nnet3/nnet-utils.h"
2425
2526namespace kaldi {
2627namespace nnet3 {
2728
2829NnetChainTrainer2::NnetChainTrainer2 (const NnetChainTraining2Options &opts,
29- const NnetChainModel2 &model,
30- Nnet *nnet):
31- opts_ (opts),
32- model_ (model),
33- nnet_ (nnet),
34- compiler_ (*nnet, opts_.nnet_config.optimize_config,
35- opts_.nnet_config.compiler_config),
36- num_minibatches_processed_ (0 ),
37- max_change_stats_ (*nnet),
38- srand_seed_ (RandInt(0 , 100000 )) {
30+ const NnetChainModel2 &model,
31+ Nnet *nnet)
32+ : opts_(opts), model_(model), nnet_(nnet),
33+ compiler_ (*nnet, opts_.nnet_config.optimize_config,
34+ opts_.nnet_config.compiler_config),
35+ num_minibatches_processed_(0 ),
36+ max_change_stats_(*nnet),
37+ srand_seed_(RandInt(0 , 100000 )) {
3938
4039 if (opts.nnet_config .zero_component_stats )
4140 ZeroComponentStats (nnet);
@@ -50,7 +49,8 @@ NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts,
5049 try {
5150 Input ki (opts.nnet_config .read_cache , &binary);
5251 compiler_.ReadCache (ki.Stream (), binary);
53- KALDI_LOG << " Read computation cache from " << opts.nnet_config .read_cache ;
52+ KALDI_LOG << " Read computation cache from "
53+ << opts.nnet_config .read_cache ;
5454 } catch (...) {
5555 KALDI_WARN << " Could not open cached computation. "
5656 " Probably this is the first training iteration." ;
@@ -59,23 +59,26 @@ NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts,
5959}
6060
6161
62- void NnetChainTrainer2::Train (const std::string &key, NnetChainExample &chain_eg) {
62+ void NnetChainTrainer2::Train (const std::string &key,
63+ NnetChainExample &chain_eg) {
6364 bool need_model_derivative = true ;
6465 const NnetTrainerOptions &nnet_config = opts_.nnet_config ;
6566 bool use_xent_regularization = (opts_.chain_config .xent_regularize != 0.0 );
6667 ComputationRequest request;
6768 std::string lang_name = " default" ;
6869 ParseFromQueryString (key, " lang" , &lang_name);
69- for (size_t i = 0 ; i < chain_eg.outputs .size (); i++) {
70- // there will normally be exactly one output , named "output"
71- if (chain_eg.outputs [i].name .compare (" output" )==0 )
72- chain_eg.outputs [i].name = " output-" + lang_name;
70+ for (size_t i = 0 ; i < chain_eg.outputs .size (); ++i) {
71+ // there will normally be exactly one output, named "output"
72+ if (chain_eg.outputs [i].name .compare (" output" ) == 0 ) {
73+ chain_eg.outputs [i].name = " output-" + lang_name;
74+ }
7375 }
7476 GetChainComputationRequest (*nnet_, chain_eg, need_model_derivative,
7577 nnet_config.store_component_stats ,
7678 use_xent_regularization, need_model_derivative,
7779 &request);
78- std::shared_ptr<const NnetComputation> computation = compiler_.Compile (request);
80+ std::shared_ptr<const NnetComputation> computation =
81+ compiler_.Compile (request);
7982
8083 if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_
8184 % nnet_config.backstitch_training_interval ==
@@ -103,9 +106,9 @@ void NnetChainTrainer2::Train(const std::string &key, NnetChainExample &chain_eg
103106}
104107
105108void NnetChainTrainer2::TrainInternal (const std::string &key,
106- const NnetChainExample &eg,
107- const NnetComputation &computation,
108- const std::string &lang_name) {
109+ const NnetChainExample &eg,
110+ const NnetComputation &computation,
111+ const std::string &lang_name) {
109112 const NnetTrainerOptions &nnet_config = opts_.nnet_config ;
110113 // note: because we give the 1st arg (nnet_) as a pointer to the
111114 // constructor of 'computer', it will use that copy of the nnet to
@@ -143,15 +146,16 @@ void NnetChainTrainer2::TrainInternal(const std::string &key,
143146 ConstrainOrthonormal (nnet_);
144147
145148 // Scale delta_nnet
146- if (success)
149+ if (success) {
147150 ScaleNnet (nnet_config.momentum , delta_nnet_);
148- else
151+ } else {
149152 ScaleNnet (0.0 , delta_nnet_);
153+ }
150154}
151155
152- void NnetChainTrainer2::TrainInternalBackstitch (const std::string key, const NnetChainExample &eg,
153- const NnetComputation &computation ,
154- bool is_backstitch_step1) {
156+ void NnetChainTrainer2::TrainInternalBackstitch (
157+ const std::string key, const NnetChainExample &eg ,
158+ const NnetComputation &computation, bool is_backstitch_step1) {
155159 const NnetTrainerOptions &nnet_config = opts_.nnet_config ;
156160 // note: because we give the 1st arg (nnet_) as a pointer to the
157161 // constructor of 'computer', it will use that copy of the nnet to
@@ -241,7 +245,8 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
241245
242246 BaseFloat tot_objf, tot_l2_term, tot_weight;
243247
244- ComputeChainObjfAndDeriv (opts_.chain_config , *(model_.GetDenGraphForLang (lang_name)),
248+ ComputeChainObjfAndDeriv (opts_.chain_config ,
249+ *(model_.GetDenGraphForLang (lang_name)),
245250 sup.supervision , nnet_output,
246251 &tot_objf, &tot_l2_term, &tot_weight,
247252 &nnet_output_deriv,
@@ -263,8 +268,9 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
263268 if (opts_.apply_deriv_weights && sup.deriv_weights .Dim () != 0 ) {
264269 CuVector<BaseFloat> cu_deriv_weights (sup.deriv_weights );
265270 nnet_output_deriv.MulRowsVec (cu_deriv_weights);
266- if (use_xent)
271+ if (use_xent) {
267272 xent_deriv.MulRowsVec (cu_deriv_weights);
273+ }
268274 }
269275
270276 /* computer->AcceptInput(sup.name, &nnet_output_deriv); */
@@ -284,103 +290,73 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
284290}
285291
286292bool NnetChainTrainer2::PrintTotalStats () const {
287- unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
288- iter = objf_info_.begin (),
289- end = objf_info_.end ();
290293 bool ans = false ;
291- for (; iter != end; ++iter ) {
292- const std::string &name = iter-> first ;
293- const ObjectiveFunctionInfo &info = iter-> second ;
294+ for (const auto &name_info : objf_info_ ) {
295+ const std::string &name = name_info. first ;
296+ const ObjectiveFunctionInfo &info = name_info. second ;
294297 ans = info.PrintTotalStats (name) || ans;
295298 }
296299 max_change_stats_.Print (*nnet_);
297300 return ans;
298301}
299302
300303NnetChainTrainer2::~NnetChainTrainer2 () {
301- if (opts_.nnet_config .write_cache != " " ) {
302- Output ko (opts_.nnet_config .write_cache , opts_.nnet_config .binary_write_cache );
304+ if (!opts_.nnet_config .write_cache .empty ()) {
305+ Output ko (opts_.nnet_config .write_cache ,
306+ opts_.nnet_config .binary_write_cache );
303307 compiler_.WriteCache (ko.Stream (), opts_.nnet_config .binary_write_cache );
304308 KALDI_LOG << " Wrote computation cache to " << opts_.nnet_config .write_cache ;
305309 }
306310 delete delta_nnet_;
307311}
308312
309- NnetChainModel2::NnetChainModel2 (
310- const NnetChainTraining2Options &opts,
311- Nnet *nnet,
312- const std::string &den_fst_dir
313- ):
314- opts_ (opts),
315- nnet (nnet),
316- den_fst_dir_ (den_fst_dir) {
317- }
318-
319- NnetChainModel2::~NnetChainModel2 () {
320- }
313+ NnetChainModel2::NnetChainModel2 (const NnetChainTraining2Options& /* unused */ ,
314+ Nnet *nnet, const std::string &den_fst_dir)
315+ : nnet(nnet), den_fst_dir_(den_fst_dir) {}
321316
322- NnetChainModel2::LanguageInfo::LanguageInfo (
323- const NnetChainModel2::LanguageInfo &other):
324- name (other.name),
325- den_graph (other.den_graph)
326- { }
317+ NnetChainModel2::~NnetChainModel2 () {}
327318
319+ /* fst::StdVectorFst* NnetChainModel2::GetDenFstForLang( */
320+ /* const std::string &language_name) { */
321+ /* LanguageInfo *info = GetInfoForLang(language_name); */
322+ /* return &(info->den_fst); */
323+ /* } */
328324
329- NnetChainModel2::LanguageInfo::LanguageInfo (
330- const std::string &name,
331- const fst::StdVectorFst &den_fst,
332- int32 num_pdfs):
333- name (name),
334- den_graph (den_fst, num_pdfs){
325+ const chain::DenominatorGraph *NnetChainModel2::GetDenGraphForLang (
326+ const std::string &lang) {
327+ const LanguageInfo *info = GetInfoForLang (lang);
328+ KALDI_ASSERT (info != nullptr );
329+ return &(info->den_graph );
335330}
336331
337- void NnetChainModel2::GetPathname (const std::string &dir,
338- const std::string &name,
339- const std::string &suffix,
340- std::string *pathname) {
341- std::ostringstream str;
342- str << dir << ' /' << name << ' .' << suffix;
343- *pathname = str.str ();
344- }
332+ namespace {
345333
346- void NnetChainModel2::GetPathname (const std::string &dir,
347- const std::string &name,
348- int32 job_id,
349- const std::string &suffix,
350- std::string *pathname) {
351- std::ostringstream str;
352- str << dir << ' /' << name << ' .' << job_id << ' .' << suffix;
353- *pathname = str.str ();
334+ // Get a pathname in the form '<dir>/<name>.<suffix>'.
335+ std::string GetPathname (const std::string &dir,
336+ const std::string &name,
337+ const std::string &suffix) {
338+ return dir + ' /' + name + ' .' + suffix;
354339}
355340
356- NnetChainModel2::LanguageInfo *NnetChainModel2::GetInfoForLang (
341+ } // namespace
342+
343+ const NnetChainModel2::LanguageInfo *NnetChainModel2::GetInfoForLang (
357344 const std::string &lang) {
358- auto iter = lang_info_.find (lang);
359- if (iter != lang_info_.end ()) {
360- return iter->second ;
361- } else {
362- std::string den_fst_filename;
363- GetPathname (den_fst_dir_, lang, " den.fst" , &den_fst_filename);
345+ // Using .count() as a Boolean "doesn't contain" is idiomatic. Fixed in C++20.
346+ if (!lang_info_.count (lang)) {
347+ std::string den_fst_filename = GetPathname (den_fst_dir_, lang, " den.fst" );
364348 fst::StdVectorFst den_fst;
365349 ReadFstKaldi (den_fst_filename, &den_fst);
366- std::string outputname = " output-" + lang;
367-
368- LanguageInfo *info = new LanguageInfo (lang, den_fst, nnet->OutputDim (outputname));
369- lang_info_[lang] = info;
370- return info;
350+ std::string output = " output-" + lang;
351+ lang_info_.emplace (lang,
352+ LanguageInfo{lang, den_fst, nnet->OutputDim (output)});
371353 }
354+ // Use .at(), not operator[](), because the [] requires a default constructor.
355+ // The .at() throws if the element isn't found, which works in lieu of an
356+ // assertion here.
357+ return &lang_info_.at (lang);
372358}
373359
374- /* fst::StdVectorFst* NnetChainModel2::GetDenFstForLang( */
375- /* const std::string &language_name) { */
376- /* LanguageInfo *info = GetInfoForLang(language_name); */
377- /* return &(info->den_fst); */
378- /* } */
379-
380- chain::DenominatorGraph *NnetChainModel2::GetDenGraphForLang (const std::string &language_name){
381- LanguageInfo *info = GetInfoForLang (language_name);
382- return &(info->den_graph );
383- }
384- } // namespace nnet3
385- } // namespace kaldi
386360
361+ } // namespace nnet3
362+ } // namespace kaldi
0 commit comments