Skip to content

Commit 20738f7

Browse files
fanlumegazone87
authored and committed
support loading kaldi model in python (kaldi-asr#3976)
* support load kaldi model in python
* add some component
* split one file to multi component wrap files
* fix some bugs and add test mdl
* add testmode func in batchnorm pybind
* change StatsSum StatsSumsq to Mean Var
* make const
1 parent 4b45385 commit 20738f7

21 files changed

+574
-0
lines changed

src/nnet3/nnet-convolutional-component.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ class TdnnComponent: public UpdatableComponent {
553553

554554
CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }
555555

556+
const CuMatrix<BaseFloat> &Linearparams() const { return linear_params_; }
557+
556558
// This allows you to resize the vector in order to add a bias where
557559
// there previously was none-- obviously this should be done carefully.
558560
CuVector<BaseFloat> &BiasParams() { return bias_params_; }

src/nnet3/nnet-normalize-component.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,24 @@ void BatchNormComponent::Write(std::ostream &os, bool binary) const {
641641
WriteToken(os, binary, "</BatchNormComponent>");
642642
}
643643

644+
CuVector<BaseFloat> BatchNormComponent::Mean() const {
645+
CuVector<BaseFloat> mean(stats_sum_);
646+
if (count_ != 0) {
647+
mean.Scale(1.0 / count_);
648+
}
649+
return mean;
650+
}
651+
652+
CuVector<BaseFloat> BatchNormComponent::Var() const {
653+
CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
654+
if (count_ != 0) {
655+
mean.Scale(1.0 / count_);
656+
var.Scale(1.0 / count_);
657+
var.AddVecVec(-1.0, mean, mean, 1.0);
658+
}
659+
return var;
660+
}
661+
644662
void BatchNormComponent::Scale(BaseFloat scale) {
645663
if (scale == 0) {
646664
count_ = 0.0;

src/nnet3/nnet-normalize-component.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ class BatchNormComponent: public Component {
224224
const CuVector<BaseFloat> &Offset() const { return offset_; }
225225
const CuVector<BaseFloat> &Scale() const { return scale_; }
226226

227+
CuVector<BaseFloat> Mean() const;
228+
CuVector<BaseFloat> Var() const;
229+
double Count() const { return count_; }
230+
BaseFloat Eps() const { return epsilon_; }
231+
227232
private:
228233

229234
struct Memo {

src/nnet3/nnet-simple-component.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ class LinearComponent: public UpdatableComponent {
971971
BaseFloat OrthonormalConstraint() const { return orthonormal_constraint_; }
972972
CuMatrixBase<BaseFloat> &Params() { return params_; }
973973
const CuMatrixBase<BaseFloat> &Params() const { return params_; }
974+
const CuMatrix<BaseFloat> &Params2() const { return params_; }
974975
private:
975976

976977
// disallow assignment operator.

src/pybind/Makefile

100644 → 100755
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \
9090
nnet3/nnet3_pybind.cc \
9191
nnet3/nnet_chain_example_pybind.cc \
9292
nnet3/nnet_common_pybind.cc \
93+
nnet3/nnet_component_itf_pybind.cc \
94+
nnet3/nnet_convolutional_component_pybind.cc \
9395
nnet3/nnet_example_pybind.cc \
96+
nnet3/nnet_nnet_pybind.cc \
97+
nnet3/nnet_normalize_component_pybind.cc \
98+
nnet3/nnet_simple_component_pybind.cc \
9499
tests/test_dlpack_subvector.cc \
95100
util/kaldi_holder_pybind.cc \
96101
util/kaldi_io_pybind.cc \

src/pybind/kaldi/io_util.py

100644 → 100755
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
8282
ki.Close()
8383

8484
return trans_model
85+
86+
87+
def read_nnet3_model(rxfilename):
88+
'''Read nnet model from an rxfilename.
89+
'''
90+
ki = kaldi_pybind.Input()
91+
is_opened, is_binary = ki.Open(rxfilename, read_header=True)
92+
if not is_opened:
93+
raise FileNotOpenException('Failed to open {}'.format(rxfilename))
94+
95+
nnet = kaldi_pybind.nnet3.Nnet()
96+
nnet.Read(ki.Stream(), is_binary)
97+
98+
ki.Close()
99+
100+
return nnet

src/pybind/nnet3/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

22
test:
33
python3 ./nnet_chain_example_pybind_test.py
4+
python3 ./nnet_nnet_pybind_test.py
45

src/pybind/nnet3/final.mdl

322 KB
Binary file not shown.

src/pybind/nnet3/nnet3_pybind.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// Copyright 2019 Mobvoi AI Lab, Beijing, China
44
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)
5+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
56

67
// See ../../../COPYING for clarification regarding multiple authors
78
//
@@ -22,12 +23,22 @@
2223

2324
#include "nnet3/nnet_chain_example_pybind.h"
2425
#include "nnet3/nnet_common_pybind.h"
26+
#include "nnet3/nnet_component_itf_pybind.h"
27+
#include "nnet3/nnet_convolutional_component_pybind.h"
2528
#include "nnet3/nnet_example_pybind.h"
29+
#include "nnet3/nnet_nnet_pybind.h"
30+
#include "nnet3/nnet_normalize_component_pybind.h"
31+
#include "nnet3/nnet_simple_component_pybind.h"
2632

2733
void pybind_nnet3(py::module& _m) {
2834
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");
2935

3036
pybind_nnet_common(m);
37+
pybind_nnet_component_itf(m);
38+
pybind_nnet_convolutional_component(m);
3139
pybind_nnet_example(m);
3240
pybind_nnet_chain_example(m);
41+
pybind_nnet_nnet(m);
42+
pybind_nnet_normalize_component(m);
43+
pybind_nnet_simple_component(m);
3344
}

src/pybind/nnet3/nnet_chain_example_pybind_test.py

100755 → 100644
File mode changed.

0 commit comments

Comments
 (0)