Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pybind/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_example_pybind.cc \
nnet3/nnet_nnet_pybind.cc \
tests/test_dlpack_subvector.cc \
util/kaldi_holder_pybind.cc \
util/kaldi_io_pybind.cc \
Expand Down
16 changes: 16 additions & 0 deletions src/pybind/kaldi/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
ki.Close()

return trans_model


def read_nnet_model(rxfilename):
'''Read nnet model from an rxfilename.
'''
ki = kaldi_pybind.Input()
is_opened, is_binary = ki.Open(rxfilename, read_header=True)
if not is_opened:
raise FileNotOpenException('Failed to open {}'.format(rxfilename))

nnet = kaldi_pybind.nnet3.Nnet()
nnet.Read(ki.Stream(), is_binary)

ki.Close()

return nnet
2 changes: 2 additions & 0 deletions src/pybind/nnet3/nnet3_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
#include "nnet3/nnet_chain_example_pybind.h"
#include "nnet3/nnet_common_pybind.h"
#include "nnet3/nnet_example_pybind.h"
#include "nnet3/nnet_nnet_pybind.h"

void pybind_nnet3(py::module& _m) {
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

pybind_nnet_common(m);
pybind_nnet_example(m);
pybind_nnet_chain_example(m);
pybind_nnet_nnet(m);
}
108 changes: 108 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// pybind/nnet3/nnet_nnet_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_nnet_pybind.h"

#include "nnet3/nnet-component-itf.h"
#include "nnet3/nnet-convolutional-component.h"
#include "nnet3/nnet-nnet.h"
#include "nnet3/nnet-normalize-component.h"
#include "nnet3/nnet-simple-component.h"

using namespace kaldi;
using namespace kaldi::nnet3;

template <typename... Args>
using overload_cast_ = py::detail::overload_cast_impl<Args...>;

void pybind_nnet_nnet(py::module& m) {
using Comp = kaldi::nnet3::Component;
py::class_<Comp>(m, "Component",
"Abstract base-class for neural-net components.")
.def("Type", &Comp::Type,
"Returns a string such as \"SigmoidComponent\", describing the "
"type of the object.")
.def("Info", &Comp::Info,
"Returns some text-form information about this component, for "
"diagnostics. Starts with the type of the component. E.g. "
"\"SigmoidComponent dim=900\", although most components will have "
"much more info.")
.def_static("NewComponentOfType", &Comp::NewComponentOfType,
py::return_value_policy::take_ownership);

using BNC = kaldi::nnet3::BatchNormComponent;
py::class_<BNC>(m, "BatchNormComponent")
.def("Type", &BNC::Type)
.def("Offset", &BNC::Offset)
.def("Scale", overload_cast_<>()(&BNC::Scale, py::const_));

using FAC = kaldi::nnet3::FixedAffineComponent;
py::class_<FAC>(m, "FixedAffineComponent")
.def("Type", &FAC::Type)
.def("LinearParams", &FAC::LinearParams)
.def("BiasParams", &FAC::BiasParams);

using LC = kaldi::nnet3::LinearComponent;
py::class_<LC>(m, "LinearComponent")
.def("Type", &LC::Type)
.def("Params", overload_cast_<>()(&LC::Params, py::const_));

using NGAC = kaldi::nnet3::NaturalGradientAffineComponent;
py::class_<NGAC>(m, "NaturalGradientAffineComponent")
.def("Type", &NGAC::Type)
.def("LinearParams", overload_cast_<>()(&NGAC::LinearParams, py::const_))
.def("BiasParams", overload_cast_<>()(&NGAC::BiasParams, py::const_));

using AC = kaldi::nnet3::AffineComponent;
py::class_<AC>(m, "AffineComponent")
.def("Type", &AC::Type)
.def("LinearParams", overload_cast_<>()(&AC::LinearParams, py::const_))
.def("BiasParams", overload_cast_<>()(&AC::BiasParams, py::const_));

using TC = kaldi::nnet3::TdnnComponent;
py::class_<TC>(m, "TdnnComponent")
.def("Type", &TC::Type)
.def("LinearParams", &TC::LinearParams)
.def("BiasParams", &TC::BiasParams);

using PyClass = kaldi::nnet3::Nnet;
auto nnet = py::class_<PyClass>(
m, "Nnet",
"This function can be used either to initialize a new Nnet from a "
"config file, or to add to an existing Nnet, possibly replacing "
"certain parts of it. It will die with error if something went wrong. "
"Also see the function ReadEditConfig() in nnet-utils.h (it's made a "
"non-member because it doesn't need special access).");
nnet.def(py::init<>())
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
.def("GetComponentNames", &PyClass::GetComponentNames,
"returns vector of component names (needed by some parsing code, "
"for instance).",
py::return_value_policy::reference)
.def("GetComponentName", &PyClass::GetComponentName,
py::arg("component_index"))
.def("Info", &PyClass::Info,
"returns some human-readable information about the network, "
"mostly for debugging purposes. Also see function NnetInfo() in "
"nnet-utils.h, which prints out more extensive infoformation.")
.def("NumComponents", &PyClass::NumComponents)
.def("NumNodes", &PyClass::NumNodes)
.def("GetComponent", (Comp * (PyClass::*)(int32)) & PyClass::GetComponent,
py::arg("c"), py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_nnet_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_nnet(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_