Skip to content

Commit 4b30697

Browse files
committed
add test binary
1 parent 5503ba1 commit 4b30697

File tree

5 files changed

+392
-17
lines changed

5 files changed

+392
-17
lines changed

src/bin/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \
2222
matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \
2323
vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \
2424
transform-vec align-text matrix-dim post-to-smat compile-graph \
25-
compare-int-vector
25+
compare-int-vector latgen-faster-mapped-combine
2626

2727

2828
OBJFILES =

src/decoder/decoder-wrappers.cc

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,4 +546,303 @@ void AlignUtteranceWrapper(
546546
}
547547
}
548548

549+
// For lattice-faster-decoder-combine
550+
DecodeUtteranceLatticeFasterCombineClass::DecodeUtteranceLatticeFasterCombineClass(
551+
LatticeFasterDecoderCombine *decoder,
552+
DecodableInterface *decodable,
553+
const TransitionModel &trans_model,
554+
const fst::SymbolTable *word_syms,
555+
std::string utt,
556+
BaseFloat acoustic_scale,
557+
bool determinize,
558+
bool allow_partial,
559+
Int32VectorWriter *alignments_writer,
560+
Int32VectorWriter *words_writer,
561+
CompactLatticeWriter *compact_lattice_writer,
562+
LatticeWriter *lattice_writer,
563+
double *like_sum, // on success, adds likelihood to this.
564+
int64 *frame_sum, // on success, adds #frames to this.
565+
int32 *num_done, // on success (including partial decode), increments this.
566+
int32 *num_err, // on failure, increments this.
567+
int32 *num_partial): // If partial decode (final-state not reached), increments this.
568+
decoder_(decoder), decodable_(decodable), trans_model_(&trans_model),
569+
word_syms_(word_syms), utt_(utt), acoustic_scale_(acoustic_scale),
570+
determinize_(determinize), allow_partial_(allow_partial),
571+
alignments_writer_(alignments_writer),
572+
words_writer_(words_writer),
573+
compact_lattice_writer_(compact_lattice_writer),
574+
lattice_writer_(lattice_writer),
575+
like_sum_(like_sum), frame_sum_(frame_sum),
576+
num_done_(num_done), num_err_(num_err),
577+
num_partial_(num_partial),
578+
computed_(false), success_(false), partial_(false),
579+
clat_(NULL), lat_(NULL) { }
580+
581+
582+
void DecodeUtteranceLatticeFasterCombineClass::operator () () {
583+
// Decoding and lattice determinization happens here.
584+
computed_ = true; // Just means this function was called-- a check on the
585+
// calling code.
586+
success_ = true;
587+
using fst::VectorFst;
588+
if (!decoder_->Decode(decodable_)) {
589+
KALDI_WARN << "Failed to decode file " << utt_;
590+
success_ = false;
591+
}
592+
if (!decoder_->ReachedFinal()) {
593+
if (allow_partial_) {
594+
KALDI_WARN << "Outputting partial output for utterance " << utt_
595+
<< " since no final-state reached\n";
596+
partial_ = true;
597+
} else {
598+
KALDI_WARN << "Not producing output for utterance " << utt_
599+
<< " since no final-state reached and "
600+
<< "--allow-partial=false.\n";
601+
success_ = false;
602+
}
603+
}
604+
if (!success_) return;
605+
606+
// Get lattice, and do determinization if requested.
607+
lat_ = new Lattice;
608+
decoder_->GetRawLattice(lat_);
609+
if (lat_->NumStates() == 0)
610+
KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt_;
611+
fst::Connect(lat_);
612+
if (determinize_) {
613+
clat_ = new CompactLattice;
614+
if (!DeterminizeLatticePhonePrunedWrapper(
615+
*trans_model_,
616+
lat_,
617+
decoder_->GetOptions().lattice_beam,
618+
clat_,
619+
decoder_->GetOptions().det_opts))
620+
KALDI_WARN << "Determinization finished earlier than the beam for "
621+
<< "utterance " << utt_;
622+
delete lat_;
623+
lat_ = NULL;
624+
// We'll write the lattice without acoustic scaling.
625+
if (acoustic_scale_ != 0.0)
626+
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), clat_);
627+
} else {
628+
// We'll write the lattice without acoustic scaling.
629+
if (acoustic_scale_ != 0.0)
630+
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale_), lat_);
631+
}
632+
}
633+
634+
DecodeUtteranceLatticeFasterCombineClass::~DecodeUtteranceLatticeFasterCombineClass() {
635+
if (!computed_)
636+
KALDI_ERR << "Destructor called without operator (), error in calling code.";
637+
638+
if (!success_) {
639+
if (num_err_ != NULL) (*num_err_)++;
640+
} else { // successful decode.
641+
// Getting the one-best output is lightweight enough that we can do it in
642+
// the destructor (easier than adding more variables to the class, and
643+
// will rarely slow down the main thread.)
644+
double likelihood;
645+
LatticeWeight weight;
646+
int32 num_frames;
647+
{ // First do some stuff with word-level traceback...
648+
// This is basically for diagnostics.
649+
fst::VectorFst<LatticeArc> decoded;
650+
decoder_->GetBestPath(&decoded);
651+
if (decoded.NumStates() == 0) {
652+
// Shouldn't really reach this point as already checked success.
653+
KALDI_ERR << "Failed to get traceback for utterance " << utt_;
654+
}
655+
std::vector<int32> alignment;
656+
std::vector<int32> words;
657+
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
658+
num_frames = alignment.size();
659+
if (words_writer_->IsOpen())
660+
words_writer_->Write(utt_, words);
661+
if (alignments_writer_->IsOpen())
662+
alignments_writer_->Write(utt_, alignment);
663+
if (word_syms_ != NULL) {
664+
std::cerr << utt_ << ' ';
665+
for (size_t i = 0; i < words.size(); i++) {
666+
std::string s = word_syms_->Find(words[i]);
667+
if (s == "")
668+
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
669+
std::cerr << s << ' ';
670+
}
671+
std::cerr << '\n';
672+
}
673+
likelihood = -(weight.Value1() + weight.Value2());
674+
}
675+
676+
// Ouptut the lattices.
677+
if (determinize_) { // CompactLattice output.
678+
KALDI_ASSERT(compact_lattice_writer_ != NULL && clat_ != NULL);
679+
if (clat_->NumStates() == 0) {
680+
KALDI_WARN << "Empty lattice for utterance " << utt_;
681+
} else {
682+
compact_lattice_writer_->Write(utt_, *clat_);
683+
}
684+
delete clat_;
685+
clat_ = NULL;
686+
} else {
687+
KALDI_ASSERT(lattice_writer_ != NULL && lat_ != NULL);
688+
if (lat_->NumStates() == 0) {
689+
KALDI_WARN << "Empty lattice for utterance " << utt_;
690+
} else {
691+
lattice_writer_->Write(utt_, *lat_);
692+
}
693+
delete lat_;
694+
lat_ = NULL;
695+
}
696+
697+
// Print out logging information.
698+
KALDI_LOG << "Log-like per frame for utterance " << utt_ << " is "
699+
<< (likelihood / num_frames) << " over "
700+
<< num_frames << " frames.";
701+
KALDI_VLOG(2) << "Cost for utterance " << utt_ << " is "
702+
<< weight.Value1() << " + " << weight.Value2();
703+
704+
// Now output the various diagnostic variables.
705+
if (like_sum_ != NULL) *like_sum_ += likelihood;
706+
if (frame_sum_ != NULL) *frame_sum_ += num_frames;
707+
if (num_done_ != NULL) (*num_done_)++;
708+
if (partial_ && num_partial_ != NULL) (*num_partial_)++;
709+
}
710+
// We were given ownership of these two objects that were passed in in
711+
// the initializer.
712+
delete decoder_;
713+
delete decodable_;
714+
}
715+
716+
717+
// Takes care of output. Returns true on success.
718+
template <typename FST>
719+
bool DecodeUtteranceLatticeFasterCombine(
720+
LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
721+
DecodableInterface &decodable, // not const but is really an input.
722+
const TransitionModel &trans_model,
723+
const fst::SymbolTable *word_syms,
724+
std::string utt,
725+
double acoustic_scale,
726+
bool determinize,
727+
bool allow_partial,
728+
Int32VectorWriter *alignment_writer,
729+
Int32VectorWriter *words_writer,
730+
CompactLatticeWriter *compact_lattice_writer,
731+
LatticeWriter *lattice_writer,
732+
double *like_ptr) { // puts utterance's like in like_ptr on success.
733+
using fst::VectorFst;
734+
735+
if (!decoder.Decode(&decodable)) {
736+
KALDI_WARN << "Failed to decode file " << utt;
737+
return false;
738+
}
739+
if (!decoder.ReachedFinal()) {
740+
if (allow_partial) {
741+
KALDI_WARN << "Outputting partial output for utterance " << utt
742+
<< " since no final-state reached\n";
743+
} else {
744+
KALDI_WARN << "Not producing output for utterance " << utt
745+
<< " since no final-state reached and "
746+
<< "--allow-partial=false.\n";
747+
return false;
748+
}
749+
}
750+
751+
double likelihood;
752+
LatticeWeight weight;
753+
int32 num_frames;
754+
{ // First do some stuff with word-level traceback...
755+
VectorFst<LatticeArc> decoded;
756+
if (!decoder.GetBestPath(&decoded))
757+
// Shouldn't really reach this point as already checked success.
758+
KALDI_ERR << "Failed to get traceback for utterance " << utt;
759+
760+
std::vector<int32> alignment;
761+
std::vector<int32> words;
762+
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
763+
num_frames = alignment.size();
764+
if (words_writer->IsOpen())
765+
words_writer->Write(utt, words);
766+
if (alignment_writer->IsOpen())
767+
alignment_writer->Write(utt, alignment);
768+
if (word_syms != NULL) {
769+
std::cerr << utt << ' ';
770+
for (size_t i = 0; i < words.size(); i++) {
771+
std::string s = word_syms->Find(words[i]);
772+
if (s == "")
773+
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
774+
std::cerr << s << ' ';
775+
}
776+
std::cerr << '\n';
777+
}
778+
likelihood = -(weight.Value1() + weight.Value2());
779+
}
780+
781+
// Get lattice, and do determinization if requested.
782+
Lattice lat;
783+
decoder.GetRawLattice(&lat);
784+
if (lat.NumStates() == 0)
785+
KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
786+
fst::Connect(&lat);
787+
if (determinize) {
788+
CompactLattice clat;
789+
if (!DeterminizeLatticePhonePrunedWrapper(
790+
trans_model,
791+
&lat,
792+
decoder.GetOptions().lattice_beam,
793+
&clat,
794+
decoder.GetOptions().det_opts))
795+
KALDI_WARN << "Determinization finished earlier than the beam for "
796+
<< "utterance " << utt;
797+
// We'll write the lattice without acoustic scaling.
798+
if (acoustic_scale != 0.0)
799+
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
800+
compact_lattice_writer->Write(utt, clat);
801+
} else {
802+
// We'll write the lattice without acoustic scaling.
803+
if (acoustic_scale != 0.0)
804+
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
805+
lattice_writer->Write(utt, lat);
806+
}
807+
KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
808+
<< (likelihood / num_frames) << " over "
809+
<< num_frames << " frames.";
810+
KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
811+
<< weight.Value1() << " + " << weight.Value2();
812+
*like_ptr = likelihood;
813+
return true;
814+
}
815+
816+
// Instantiate the template above for the two required FST types.
817+
template bool DecodeUtteranceLatticeFasterCombine(
818+
LatticeFasterDecoderCombineTpl<fst::Fst<fst::StdArc> > &decoder,
819+
DecodableInterface &decodable,
820+
const TransitionModel &trans_model,
821+
const fst::SymbolTable *word_syms,
822+
std::string utt,
823+
double acoustic_scale,
824+
bool determinize,
825+
bool allow_partial,
826+
Int32VectorWriter *alignment_writer,
827+
Int32VectorWriter *words_writer,
828+
CompactLatticeWriter *compact_lattice_writer,
829+
LatticeWriter *lattice_writer,
830+
double *like_ptr);
831+
832+
template bool DecodeUtteranceLatticeFasterCombine(
833+
LatticeFasterDecoderCombineTpl<fst::GrammarFst> &decoder,
834+
DecodableInterface &decodable,
835+
const TransitionModel &trans_model,
836+
const fst::SymbolTable *word_syms,
837+
std::string utt,
838+
double acoustic_scale,
839+
bool determinize,
840+
bool allow_partial,
841+
Int32VectorWriter *alignment_writer,
842+
Int32VectorWriter *words_writer,
843+
CompactLatticeWriter *compact_lattice_writer,
844+
LatticeWriter *lattice_writer,
845+
double *like_ptr);
846+
847+
549848
} // end namespace kaldi.

src/decoder/decoder-wrappers.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "itf/options-itf.h"
2424
#include "decoder/lattice-faster-decoder.h"
2525
#include "decoder/lattice-simple-decoder.h"
26+
#include "decoder/lattice-faster-decoder-combine.h"
2627

2728
// This header contains declarations from various convenience functions that are called
2829
// from binary-level programs such as gmm-decode-faster.cc, gmm-align-compiled.cc, and
@@ -196,6 +197,78 @@ bool DecodeUtteranceLatticeSimple(
196197
double *like_ptr); // puts utterance's likelihood in like_ptr on success.
197198

198199

200+
// For lattice-faster-decoder-combine
201+
template <typename FST>
202+
bool DecodeUtteranceLatticeFasterCombine(
203+
LatticeFasterDecoderCombineTpl<FST> &decoder, // not const but is really an input.
204+
DecodableInterface &decodable, // not const but is really an input.
205+
const TransitionModel &trans_model,
206+
const fst::SymbolTable *word_syms,
207+
std::string utt,
208+
double acoustic_scale,
209+
bool determinize,
210+
bool allow_partial,
211+
Int32VectorWriter *alignments_writer,
212+
Int32VectorWriter *words_writer,
213+
CompactLatticeWriter *compact_lattice_writer,
214+
LatticeWriter *lattice_writer,
215+
double *like_ptr); // puts utterance's likelihood in like_ptr on success.
216+
217+
218+
class DecodeUtteranceLatticeFasterCombineClass {
219+
public:
220+
// Initializer sets various variables.
221+
// NOTE: we "take ownership" of "decoder" and "decodable". These
222+
// are deleted by the destructor. On error, "num_err" is incremented.
223+
DecodeUtteranceLatticeFasterCombineClass(
224+
LatticeFasterDecoderCombine *decoder,
225+
DecodableInterface *decodable,
226+
const TransitionModel &trans_model,
227+
const fst::SymbolTable *word_syms,
228+
std::string utt,
229+
BaseFloat acoustic_scale,
230+
bool determinize,
231+
bool allow_partial,
232+
Int32VectorWriter *alignments_writer,
233+
Int32VectorWriter *words_writer,
234+
CompactLatticeWriter *compact_lattice_writer,
235+
LatticeWriter *lattice_writer,
236+
double *like_sum, // on success, adds likelihood to this.
237+
int64 *frame_sum, // on success, adds #frames to this.
238+
int32 *num_done, // on success (including partial decode), increments this.
239+
int32 *num_err, // on failure, increments this.
240+
int32 *num_partial); // If partial decode (final-state not reached), increments this.
241+
void operator () (); // The decoding happens here.
242+
~DecodeUtteranceLatticeFasterCombineClass(); // Output happens here.
243+
private:
244+
// The following variables correspond to inputs:
245+
LatticeFasterDecoderCombine *decoder_;
246+
DecodableInterface *decodable_;
247+
const TransitionModel *trans_model_;
248+
const fst::SymbolTable *word_syms_;
249+
std::string utt_;
250+
BaseFloat acoustic_scale_;
251+
bool determinize_;
252+
bool allow_partial_;
253+
Int32VectorWriter *alignments_writer_;
254+
Int32VectorWriter *words_writer_;
255+
CompactLatticeWriter *compact_lattice_writer_;
256+
LatticeWriter *lattice_writer_;
257+
double *like_sum_;
258+
int64 *frame_sum_;
259+
int32 *num_done_;
260+
int32 *num_err_;
261+
int32 *num_partial_;
262+
263+
// The following variables are stored by the computation.
264+
bool computed_; // operator () was called.
265+
bool success_; // decoding succeeded (possibly partial)
266+
bool partial_; // decoding was partial.
267+
CompactLattice *clat_; // Stored output, if determinize_ == true.
268+
Lattice *lat_; // Stored output, if determinize_ == false.
269+
};
270+
271+
199272

200273
} // end namespace kaldi.
201274

0 commit comments

Comments
 (0)