Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/nnet3/nnet-convolutional-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ class TdnnComponent: public UpdatableComponent {

CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }

const CuMatrix<BaseFloat> &Linearparams() const { return linear_params_; }

// This allows you to resize the vector in order to add a bias where
// there previously was none-- obviously this should be done carefully.
CuVector<BaseFloat> &BiasParams() { return bias_params_; }
Expand Down
18 changes: 18 additions & 0 deletions src/nnet3/nnet-normalize-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,24 @@ void BatchNormComponent::Write(std::ostream &os, bool binary) const {
WriteToken(os, binary, "</BatchNormComponent>");
}

CuVector<BaseFloat> BatchNormComponent::Mean() const {
CuVector<BaseFloat> mean(stats_sum_);
if (count_ != 0) {
mean.Scale(1.0 / count_);
}
return mean;
}

CuVector<BaseFloat> BatchNormComponent::Var() const {
CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
if (count_ != 0) {
mean.Scale(1.0 / count_);
var.Scale(1.0 / count_);
var.AddVecVec(-1.0, mean, mean, 1.0);
}
return var;
}

void BatchNormComponent::Scale(BaseFloat scale) {
if (scale == 0) {
count_ = 0.0;
Expand Down
5 changes: 5 additions & 0 deletions src/nnet3/nnet-normalize-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ class BatchNormComponent: public Component {
const CuVector<BaseFloat> &Offset() const { return offset_; }
const CuVector<BaseFloat> &Scale() const { return scale_; }

CuVector<BaseFloat> Mean() const;
CuVector<BaseFloat> Var() const;
double Count() const { return count_; }
BaseFloat Eps() const { return epsilon_; }

private:

struct Memo {
Expand Down
1 change: 1 addition & 0 deletions src/nnet3/nnet-simple-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ class LinearComponent: public UpdatableComponent {
BaseFloat OrthonormalConstraint() const { return orthonormal_constraint_; }
CuMatrixBase<BaseFloat> &Params() { return params_; }
const CuMatrixBase<BaseFloat> &Params() const { return params_; }
const CuMatrix<BaseFloat> &Params2() const { return params_; }
private:

// disallow assignment operator.
Expand Down
5 changes: 5 additions & 0 deletions src/pybind/Makefile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \
nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_component_itf_pybind.cc \
nnet3/nnet_convolutional_component_pybind.cc \
nnet3/nnet_example_pybind.cc \
nnet3/nnet_nnet_pybind.cc \
nnet3/nnet_normalize_component_pybind.cc \
nnet3/nnet_simple_component_pybind.cc \
tests/test_dlpack_subvector.cc \
util/kaldi_holder_pybind.cc \
util/kaldi_io_pybind.cc \
Expand Down
16 changes: 16 additions & 0 deletions src/pybind/kaldi/io_util.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
ki.Close()

return trans_model


def read_nnet3_model(rxfilename):
'''Read nnet model from an rxfilename.
'''
ki = kaldi_pybind.Input()
is_opened, is_binary = ki.Open(rxfilename, read_header=True)
if not is_opened:
raise FileNotOpenException('Failed to open {}'.format(rxfilename))

nnet = kaldi_pybind.nnet3.Nnet()
nnet.Read(ki.Stream(), is_binary)

ki.Close()

return nnet
1 change: 1 addition & 0 deletions src/pybind/nnet3/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

test:
python3 ./nnet_chain_example_pybind_test.py
python3 ./nnet_nnet_pybind_test.py

Binary file added src/pybind/nnet3/final.mdl
Binary file not shown.
11 changes: 11 additions & 0 deletions src/pybind/nnet3/nnet3_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
Expand All @@ -22,12 +23,22 @@

#include "nnet3/nnet_chain_example_pybind.h"
#include "nnet3/nnet_common_pybind.h"
#include "nnet3/nnet_component_itf_pybind.h"
#include "nnet3/nnet_convolutional_component_pybind.h"
#include "nnet3/nnet_example_pybind.h"
#include "nnet3/nnet_nnet_pybind.h"
#include "nnet3/nnet_normalize_component_pybind.h"
#include "nnet3/nnet_simple_component_pybind.h"

void pybind_nnet3(py::module& _m) {
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

pybind_nnet_common(m);
pybind_nnet_component_itf(m);
pybind_nnet_convolutional_component(m);
pybind_nnet_example(m);
pybind_nnet_chain_example(m);
pybind_nnet_nnet(m);
pybind_nnet_normalize_component(m);
pybind_nnet_simple_component(m);
}
Empty file modified src/pybind/nnet3/nnet_chain_example_pybind_test.py
100755 → 100644
Empty file.
40 changes: 40 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// pybind/nnet3/nnet_component_itf_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_component_itf_pybind.h"

#include "nnet3/nnet-component-itf.h"

using namespace kaldi::nnet3;

void pybind_nnet_component_itf(py::module& m) {
using PyClass = Component;
py::class_<PyClass>(m, "Component",
"Abstract base-class for neural-net components.")
.def("Type", &PyClass::Type,
"Returns a string such as \"SigmoidComponent\", describing the "
"type of the object.")
.def("Info", &PyClass::Info,
"Returns some text-form information about this component, for "
"diagnostics. Starts with the type of the component. E.g. "
"\"SigmoidComponent dim=900\", although most components will have "
"much more info.")
.def_static("NewComponentOfType", &PyClass::NewComponentOfType,
py::return_value_policy::take_ownership);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_component_itf_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_component_itf(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
33 changes: 33 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// pybind/nnet3/nnet_convolutional_component_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_convolutional_component_pybind.h"

#include "nnet3/nnet-convolutional-component.h"

using namespace kaldi::nnet3;

void pybind_nnet_convolutional_component(py::module& m) {
using TC = kaldi::nnet3::TdnnComponent;
py::class_<TC, Component>(m, "TdnnComponent")
.def("LinearParams", &TC::Linearparams,
py::return_value_policy::reference)
.def("BiasParams", &TC::BiasParams,
py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_convolutional_component_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_convolutional_component(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
52 changes: 52 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// pybind/nnet3/nnet_nnet_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_nnet_pybind.h"

#include "nnet3/nnet-nnet.h"

using namespace kaldi;
using namespace kaldi::nnet3;

void pybind_nnet_nnet(py::module& m) {
using PyClass = kaldi::nnet3::Nnet;
auto nnet = py::class_<PyClass>(
m, "Nnet",
"This function can be used either to initialize a new Nnet from a "
"config file, or to add to an existing Nnet, possibly replacing "
"certain parts of it. It will die with error if something went wrong. "
"Also see the function ReadEditConfig() in nnet-utils.h (it's made a "
"non-member because it doesn't need special access).");
nnet.def(py::init<>())
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
.def("GetComponentNames", &PyClass::GetComponentNames,
"returns vector of component names (needed by some parsing code, "
"for instance).",
py::return_value_policy::reference)
.def("GetComponentName", &PyClass::GetComponentName,
py::arg("component_index"))
.def("Info", &PyClass::Info,
"returns some human-readable information about the network, "
"mostly for debugging purposes. Also see function NnetInfo() in "
"nnet-utils.h, which prints out more extensive infoformation.")
.def("NumComponents", &PyClass::NumComponents)
.def("NumNodes", &PyClass::NumNodes)
.def("GetComponent", (Component * (PyClass::*)(int32)) & PyClass::GetComponent,
py::arg("c"), py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_nnet_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_nnet(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
Loading