Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/nnet3/nnet-convolutional-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ class TdnnComponent: public UpdatableComponent {

CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }

const CuMatrix<BaseFloat> &Linearparams() const { return linear_params_; }

// This allows you to resize the vector in order to add a bias where
// there previously was none-- obviously this should be done carefully.
CuVector<BaseFloat> &BiasParams() { return bias_params_; }
Expand Down
18 changes: 18 additions & 0 deletions src/nnet3/nnet-normalize-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,24 @@ void BatchNormComponent::Write(std::ostream &os, bool binary) const {
WriteToken(os, binary, "</BatchNormComponent>");
}

CuVector<BaseFloat> BatchNormComponent::Mean() {
CuVector<BaseFloat> mean(stats_sum_);
if (count_ != 0) {
mean.Scale(1.0 / count_);
}
return mean;
}

CuVector<BaseFloat> BatchNormComponent::Var() {
CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
if (count_ != 0) {
mean.Scale(1.0 / count_);
var.Scale(1.0 / count_);
var.AddVecVec(-1.0, mean, mean, 1.0);
}
return var;
}

void BatchNormComponent::Scale(BaseFloat scale) {
if (scale == 0) {
count_ = 0.0;
Expand Down
5 changes: 5 additions & 0 deletions src/nnet3/nnet-normalize-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ class BatchNormComponent: public Component {
const CuVector<BaseFloat> &Offset() const { return offset_; }
const CuVector<BaseFloat> &Scale() const { return scale_; }

CuVector<BaseFloat> Mean();
CuVector<BaseFloat> Var();
double Count() { return count_; }
BaseFloat Eps() { return epsilon_; }

private:

struct Memo {
Expand Down
1 change: 1 addition & 0 deletions src/nnet3/nnet-simple-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ class LinearComponent: public UpdatableComponent {
BaseFloat OrthonormalConstraint() const { return orthonormal_constraint_; }
CuMatrixBase<BaseFloat> &Params() { return params_; }
const CuMatrixBase<BaseFloat> &Params() const { return params_; }
const CuMatrix<BaseFloat> &Params2() const { return params_; }
private:

// disallow assignment operator.
Expand Down
5 changes: 5 additions & 0 deletions src/pybind/Makefile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \
nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_component_itf_pybind.cc \
nnet3/nnet_convolutional_component_pybind.cc \
nnet3/nnet_example_pybind.cc \
nnet3/nnet_nnet_pybind.cc \
nnet3/nnet_normalize_component_pybind.cc \
nnet3/nnet_simple_component_pybind.cc \
tests/test_dlpack_subvector.cc \
util/kaldi_holder_pybind.cc \
util/kaldi_io_pybind.cc \
Expand Down
16 changes: 16 additions & 0 deletions src/pybind/kaldi/io_util.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
ki.Close()

return trans_model


def read_nnet3_model(rxfilename):
'''Read nnet model from an rxfilename.
'''
ki = kaldi_pybind.Input()
is_opened, is_binary = ki.Open(rxfilename, read_header=True)
if not is_opened:
raise FileNotOpenException('Failed to open {}'.format(rxfilename))

nnet = kaldi_pybind.nnet3.Nnet()
nnet.Read(ki.Stream(), is_binary)

ki.Close()

return nnet
1 change: 1 addition & 0 deletions src/pybind/nnet3/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

test:
python3 ./nnet_chain_example_pybind_test.py
python3 ./nnet_nnet_pybind_test.py

Binary file added src/pybind/nnet3/final.mdl
Binary file not shown.
11 changes: 11 additions & 0 deletions src/pybind/nnet3/nnet3_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
Expand All @@ -22,12 +23,22 @@

#include "nnet3/nnet_chain_example_pybind.h"
#include "nnet3/nnet_common_pybind.h"
#include "nnet3/nnet_component_itf_pybind.h"
#include "nnet3/nnet_convolutional_component_pybind.h"
#include "nnet3/nnet_example_pybind.h"
#include "nnet3/nnet_nnet_pybind.h"
#include "nnet3/nnet_normalize_component_pybind.h"
#include "nnet3/nnet_simple_component_pybind.h"

void pybind_nnet3(py::module& _m) {
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

pybind_nnet_common(m);
pybind_nnet_component_itf(m);
pybind_nnet_convolutional_component(m);
pybind_nnet_example(m);
pybind_nnet_chain_example(m);
pybind_nnet_nnet(m);
pybind_nnet_normalize_component(m);
pybind_nnet_simple_component(m);
}
Empty file modified src/pybind/nnet3/nnet_chain_example_pybind_test.py
100755 → 100644
Empty file.
40 changes: 40 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// pybind/nnet3/nnet_component_itf_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_component_itf_pybind.h"

#include "nnet3/nnet-component-itf.h"

using namespace kaldi::nnet3;

void pybind_nnet_component_itf(py::module& m) {
using PyClass = Component;
py::class_<PyClass>(m, "Component",
"Abstract base-class for neural-net components.")
.def("Type", &PyClass::Type,
"Returns a string such as \"SigmoidComponent\", describing the "
"type of the object.")
.def("Info", &PyClass::Info,
"Returns some text-form information about this component, for "
"diagnostics. Starts with the type of the component. E.g. "
"\"SigmoidComponent dim=900\", although most components will have "
"much more info.")
.def_static("NewComponentOfType", &PyClass::NewComponentOfType,
py::return_value_policy::take_ownership);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_component_itf_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_component_itf(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
33 changes: 33 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// pybind/nnet3/nnet_convolutional_component_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_convolutional_component_pybind.h"

#include "nnet3/nnet-convolutional-component.h"

using namespace kaldi::nnet3;

void pybind_nnet_convolutional_component(py::module& m) {
using TC = kaldi::nnet3::TdnnComponent;
py::class_<TC, Component>(m, "TdnnComponent")
.def("LinearParams", &TC::Linearparams,
py::return_value_policy::reference)
.def("BiasParams", &TC::BiasParams,
py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_convolutional_component_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_convolutional_component(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
52 changes: 52 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// pybind/nnet3/nnet_nnet_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_nnet_pybind.h"

#include "nnet3/nnet-nnet.h"

using namespace kaldi;
using namespace kaldi::nnet3;

void pybind_nnet_nnet(py::module& m) {
using PyClass = kaldi::nnet3::Nnet;
auto nnet = py::class_<PyClass>(
m, "Nnet",
"This function can be used either to initialize a new Nnet from a "
"config file, or to add to an existing Nnet, possibly replacing "
"certain parts of it. It will die with error if something went wrong. "
"Also see the function ReadEditConfig() in nnet-utils.h (it's made a "
"non-member because it doesn't need special access).");
nnet.def(py::init<>())
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
.def("GetComponentNames", &PyClass::GetComponentNames,
"returns vector of component names (needed by some parsing code, "
"for instance).",
py::return_value_policy::reference)
.def("GetComponentName", &PyClass::GetComponentName,
py::arg("component_index"))
.def("Info", &PyClass::Info,
"returns some human-readable information about the network, "
"mostly for debugging purposes. Also see function NnetInfo() in "
"nnet-utils.h, which prints out more extensive infoformation.")
.def("NumComponents", &PyClass::NumComponents)
.def("NumNodes", &PyClass::NumNodes)
.def("GetComponent", (Component * (PyClass::*)(int32)) & PyClass::GetComponent,
py::arg("c"), py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_nnet_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_nnet(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
Loading