Skip to content

Commit 8184005

Browse files
authored
[src] Don't leak new'd NnetChainModel2::LanguageInfo (#4562)
* Heed the static-analysis warning about a `new` with no matching `delete`. * Remove an unused data member and a GetPathname overload. * Move the remaining, still-used GetPathname() helper into the .cc file, as it is a static helper function that doesn't semantically belong in the class. * Adjust formatting to the coding style (not exhaustively).
1 parent fe94896 commit 8184005

File tree

2 files changed

+113
-154
lines changed

2 files changed

+113
-154
lines changed

src/nnet3/nnet-chain-training2.cc

Lines changed: 74 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@
2020
// limitations under the License.
2121

2222
#include "nnet3/nnet-chain-training2.h"
23+
2324
#include "nnet3/nnet-utils.h"
2425

2526
namespace kaldi {
2627
namespace nnet3 {
2728

2829
NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts,
29-
const NnetChainModel2 &model,
30-
Nnet *nnet):
31-
opts_(opts),
32-
model_(model),
33-
nnet_(nnet),
34-
compiler_(*nnet, opts_.nnet_config.optimize_config,
35-
opts_.nnet_config.compiler_config),
36-
num_minibatches_processed_(0),
37-
max_change_stats_(*nnet),
38-
srand_seed_(RandInt(0, 100000)) {
30+
const NnetChainModel2 &model,
31+
Nnet *nnet)
32+
: opts_(opts), model_(model), nnet_(nnet),
33+
compiler_(*nnet, opts_.nnet_config.optimize_config,
34+
opts_.nnet_config.compiler_config),
35+
num_minibatches_processed_(0),
36+
max_change_stats_(*nnet),
37+
srand_seed_(RandInt(0, 100000)) {
3938

4039
if (opts.nnet_config.zero_component_stats)
4140
ZeroComponentStats(nnet);
@@ -50,7 +49,8 @@ NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts,
5049
try {
5150
Input ki(opts.nnet_config.read_cache, &binary);
5251
compiler_.ReadCache(ki.Stream(), binary);
53-
KALDI_LOG << "Read computation cache from " << opts.nnet_config.read_cache;
52+
KALDI_LOG << "Read computation cache from "
53+
<< opts.nnet_config.read_cache;
5454
} catch (...) {
5555
KALDI_WARN << "Could not open cached computation. "
5656
"Probably this is the first training iteration.";
@@ -59,23 +59,26 @@ NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts,
5959
}
6060

6161

62-
void NnetChainTrainer2::Train(const std::string &key, NnetChainExample &chain_eg) {
62+
void NnetChainTrainer2::Train(const std::string &key,
63+
NnetChainExample &chain_eg) {
6364
bool need_model_derivative = true;
6465
const NnetTrainerOptions &nnet_config = opts_.nnet_config;
6566
bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
6667
ComputationRequest request;
6768
std::string lang_name = "default";
6869
ParseFromQueryString(key, "lang", &lang_name);
69-
for (size_t i = 0; i < chain_eg.outputs.size(); i++) {
70-
// there will normally be exactly one output , named "output"
71-
if(chain_eg.outputs[i].name.compare("output")==0)
72-
chain_eg.outputs[i].name = "output-" + lang_name;
70+
for (size_t i = 0; i < chain_eg.outputs.size(); ++i) {
71+
// there will normally be exactly one output, named "output"
72+
if (chain_eg.outputs[i].name.compare("output") == 0) {
73+
chain_eg.outputs[i].name = "output-" + lang_name;
74+
}
7375
}
7476
GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,
7577
nnet_config.store_component_stats,
7678
use_xent_regularization, need_model_derivative,
7779
&request);
78-
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
80+
std::shared_ptr<const NnetComputation> computation =
81+
compiler_.Compile(request);
7982

8083
if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_
8184
% nnet_config.backstitch_training_interval ==
@@ -103,9 +106,9 @@ void NnetChainTrainer2::Train(const std::string &key, NnetChainExample &chain_eg
103106
}
104107

105108
void NnetChainTrainer2::TrainInternal(const std::string &key,
106-
const NnetChainExample &eg,
107-
const NnetComputation &computation,
108-
const std::string &lang_name) {
109+
const NnetChainExample &eg,
110+
const NnetComputation &computation,
111+
const std::string &lang_name) {
109112
const NnetTrainerOptions &nnet_config = opts_.nnet_config;
110113
// note: because we give the 1st arg (nnet_) as a pointer to the
111114
// constructor of 'computer', it will use that copy of the nnet to
@@ -143,15 +146,16 @@ void NnetChainTrainer2::TrainInternal(const std::string &key,
143146
ConstrainOrthonormal(nnet_);
144147

145148
// Scale delta_nnet
146-
if (success)
149+
if (success) {
147150
ScaleNnet(nnet_config.momentum, delta_nnet_);
148-
else
151+
} else {
149152
ScaleNnet(0.0, delta_nnet_);
153+
}
150154
}
151155

152-
void NnetChainTrainer2::TrainInternalBackstitch(const std::string key, const NnetChainExample &eg,
153-
const NnetComputation &computation,
154-
bool is_backstitch_step1) {
156+
void NnetChainTrainer2::TrainInternalBackstitch(
157+
const std::string key, const NnetChainExample &eg,
158+
const NnetComputation &computation, bool is_backstitch_step1) {
155159
const NnetTrainerOptions &nnet_config = opts_.nnet_config;
156160
// note: because we give the 1st arg (nnet_) as a pointer to the
157161
// constructor of 'computer', it will use that copy of the nnet to
@@ -241,7 +245,8 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
241245

242246
BaseFloat tot_objf, tot_l2_term, tot_weight;
243247

244-
ComputeChainObjfAndDeriv(opts_.chain_config, *(model_.GetDenGraphForLang(lang_name)),
248+
ComputeChainObjfAndDeriv(opts_.chain_config,
249+
*(model_.GetDenGraphForLang(lang_name)),
245250
sup.supervision, nnet_output,
246251
&tot_objf, &tot_l2_term, &tot_weight,
247252
&nnet_output_deriv,
@@ -263,8 +268,9 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
263268
if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
264269
CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
265270
nnet_output_deriv.MulRowsVec(cu_deriv_weights);
266-
if (use_xent)
271+
if (use_xent) {
267272
xent_deriv.MulRowsVec(cu_deriv_weights);
273+
}
268274
}
269275

270276
/* computer->AcceptInput(sup.name, &nnet_output_deriv); */
@@ -284,103 +290,73 @@ void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2,
284290
}
285291

286292
bool NnetChainTrainer2::PrintTotalStats() const {
287-
unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
288-
iter = objf_info_.begin(),
289-
end = objf_info_.end();
290293
bool ans = false;
291-
for (; iter != end; ++iter) {
292-
const std::string &name = iter->first;
293-
const ObjectiveFunctionInfo &info = iter->second;
294+
for (const auto &name_info : objf_info_) {
295+
const std::string &name = name_info.first;
296+
const ObjectiveFunctionInfo &info = name_info.second;
294297
ans = info.PrintTotalStats(name) || ans;
295298
}
296299
max_change_stats_.Print(*nnet_);
297300
return ans;
298301
}
299302

300303
NnetChainTrainer2::~NnetChainTrainer2() {
301-
if (opts_.nnet_config.write_cache != "") {
302-
Output ko(opts_.nnet_config.write_cache, opts_.nnet_config.binary_write_cache);
304+
if (!opts_.nnet_config.write_cache.empty()) {
305+
Output ko(opts_.nnet_config.write_cache,
306+
opts_.nnet_config.binary_write_cache);
303307
compiler_.WriteCache(ko.Stream(), opts_.nnet_config.binary_write_cache);
304308
KALDI_LOG << "Wrote computation cache to " << opts_.nnet_config.write_cache;
305309
}
306310
delete delta_nnet_;
307311
}
308312

309-
NnetChainModel2::NnetChainModel2(
310-
const NnetChainTraining2Options &opts,
311-
Nnet *nnet,
312-
const std::string &den_fst_dir
313-
):
314-
opts_(opts),
315-
nnet(nnet),
316-
den_fst_dir_(den_fst_dir) {
317-
}
318-
319-
NnetChainModel2::~NnetChainModel2() {
320-
}
313+
NnetChainModel2::NnetChainModel2(const NnetChainTraining2Options& /* unused */,
314+
Nnet *nnet, const std::string &den_fst_dir)
315+
: nnet(nnet), den_fst_dir_(den_fst_dir) {}
321316

322-
NnetChainModel2::LanguageInfo::LanguageInfo(
323-
const NnetChainModel2::LanguageInfo &other):
324-
name(other.name),
325-
den_graph(other.den_graph)
326-
{ }
317+
NnetChainModel2::~NnetChainModel2() {}
327318

319+
/* fst::StdVectorFst* NnetChainModel2::GetDenFstForLang( */
320+
/* const std::string &language_name) { */
321+
/* LanguageInfo *info = GetInfoForLang(language_name); */
322+
/* return &(info->den_fst); */
323+
/* } */
328324

329-
NnetChainModel2::LanguageInfo::LanguageInfo(
330-
const std::string &name,
331-
const fst::StdVectorFst &den_fst,
332-
int32 num_pdfs):
333-
name(name),
334-
den_graph(den_fst, num_pdfs){
325+
const chain::DenominatorGraph *NnetChainModel2::GetDenGraphForLang(
326+
const std::string &lang) {
327+
const LanguageInfo *info = GetInfoForLang(lang);
328+
KALDI_ASSERT(info != nullptr);
329+
return &(info->den_graph);
335330
}
336331

337-
void NnetChainModel2::GetPathname(const std::string &dir,
338-
const std::string &name,
339-
const std::string &suffix,
340-
std::string *pathname) {
341-
std::ostringstream str;
342-
str << dir << '/' << name << '.' << suffix;
343-
*pathname = str.str();
344-
}
332+
namespace {
345333

346-
void NnetChainModel2::GetPathname(const std::string &dir,
347-
const std::string &name,
348-
int32 job_id,
349-
const std::string &suffix,
350-
std::string *pathname) {
351-
std::ostringstream str;
352-
str << dir << '/' << name << '.' << job_id << '.' << suffix;
353-
*pathname = str.str();
334+
// Get a pathname in the form '<dir>/<name>.<suffix>'.
335+
std::string GetPathname(const std::string &dir,
336+
const std::string &name,
337+
const std::string &suffix) {
338+
return dir + '/' + name + '.' + suffix;
354339
}
355340

356-
NnetChainModel2::LanguageInfo *NnetChainModel2::GetInfoForLang(
341+
} // namespace
342+
343+
const NnetChainModel2::LanguageInfo *NnetChainModel2::GetInfoForLang(
357344
const std::string &lang) {
358-
auto iter = lang_info_.find(lang);
359-
if (iter != lang_info_.end()) {
360-
return iter->second;
361-
} else {
362-
std::string den_fst_filename;
363-
GetPathname(den_fst_dir_, lang, "den.fst", &den_fst_filename);
345+
// Using !.count() as a Boolean "doesn't contain" test is idiomatic before C++20, which adds .contains().
346+
if (!lang_info_.count(lang)) {
347+
std::string den_fst_filename = GetPathname(den_fst_dir_, lang, "den.fst");
364348
fst::StdVectorFst den_fst;
365349
ReadFstKaldi(den_fst_filename, &den_fst);
366-
std::string outputname = "output-" + lang;
367-
368-
LanguageInfo *info = new LanguageInfo(lang, den_fst, nnet->OutputDim(outputname));
369-
lang_info_[lang] = info;
370-
return info;
350+
std::string output = "output-" + lang;
351+
lang_info_.emplace(lang,
352+
LanguageInfo{lang, den_fst, nnet->OutputDim(output)});
371353
}
354+
// Use .at(), not operator[](), because the [] requires a default constructor.
355+
// The .at() throws if the element isn't found, which works in lieu of an
356+
// assertion here.
357+
return &lang_info_.at(lang);
372358
}
373359

374-
/* fst::StdVectorFst* NnetChainModel2::GetDenFstForLang( */
375-
/* const std::string &language_name) { */
376-
/* LanguageInfo *info = GetInfoForLang(language_name); */
377-
/* return &(info->den_fst); */
378-
/* } */
379-
380-
chain::DenominatorGraph *NnetChainModel2::GetDenGraphForLang(const std::string &language_name){
381-
LanguageInfo *info = GetInfoForLang(language_name);
382-
return &(info->den_graph);
383-
}
384-
} // namespace nnet3
385-
} // namespace kaldi
386360

361+
} // namespace nnet3
362+
} // namespace kaldi

0 commit comments

Comments
 (0)