Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pybind/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_example_pybind.cc \
nnet3/nnet_nnet_pybind.cc \
tests/test_dlpack_subvector.cc \
util/kaldi_holder_pybind.cc \
util/kaldi_io_pybind.cc \
Expand Down
16 changes: 16 additions & 0 deletions src/pybind/kaldi/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
ki.Close()

return trans_model


def read_nnet_model(rxfilename):
'''Read nnet model from an rxfilename.
'''
ki = kaldi_pybind.Input()
is_opened, is_binary = ki.Open(rxfilename, read_header=True)
if not is_opened:
raise FileNotOpenException('Failed to open {}'.format(rxfilename))

nnet = kaldi_pybind.nnet3.Nnet()
nnet.Read(ki.Stream(), is_binary)

ki.Close()

return nnet
2 changes: 2 additions & 0 deletions src/pybind/nnet3/nnet3_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
#include "nnet3/nnet_chain_example_pybind.h"
#include "nnet3/nnet_common_pybind.h"
#include "nnet3/nnet_example_pybind.h"
#include "nnet3/nnet_nnet_pybind.h"

void pybind_nnet3(py::module& _m) {
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

pybind_nnet_common(m);
pybind_nnet_example(m);
pybind_nnet_chain_example(m);
pybind_nnet_nnet(m);
}
49 changes: 49 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// pybind/nnet3/nnet_nnet_pybind.cc

// Copyright 2019 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_nnet_pybind.h"

#include "nnet3/nnet-nnet.h"
#include "nnet3/nnet-component-itf.h"

using namespace kaldi;
using namespace kaldi::nnet3;

void pybind_nnet_nnet(py::module& m) {
using PyClass = Nnet;
py::class_<PyClass>(
m, "Nnet",
"This function can be used either to initialize a new Nnet from a config"
"file, or to add to an existing Nnet, possibly replacing certain parts of"
"it. It will die with error if something went wrong."
"Also see the function ReadEditConfig() in nnet-utils.h (it's made a"
"non-member because it doesn't need special access)."
)
.def(py::init<>())
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
.def("GetComponentNames", &PyClass::GetComponentNames,
"returns vector of component names (needed by some parsing code, for instance).",
py::return_value_policy::reference)
.def("GetComponentName", &PyClass::GetComponentName, py::arg("component_index"))
.def("Info", &PyClass::Info)
.def("NumComponents", &PyClass::NumComponents)
.def("NumNodes", &PyClass::NumNodes)
// .def("GetComponent", &PyClass::GetComponent, py::arg("c"))
;
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_nnet_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_nnet(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_