Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/nnet3/nnet-normalize-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ class BatchNormComponent: public Component {
const CuVector<BaseFloat> &Offset() const { return offset_; }
const CuVector<BaseFloat> &Scale() const { return scale_; }

virtual const CuVector<double> &StatsSum() const { return stats_sum_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use

const CuVector<double> &StatsSum() const override { return stats_sum_; }

Replace virtual with override .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you make it a virtual method?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no method of name StatsSum in base calss, Should I use keyword override?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No.

Then please remove virtual.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@csukuangfj

  CuVector<double> &StatsSum() { return stats_sum_; }
  CuVector<double> &StatsSumsq() { return stats_sumsq_; }
  double Count() { return count_; }

I have a problem here

TypeError: Unable to convert function return value to a Python type! The signature was
        (self: kaldi_pybind.nnet3.BatchNormComponent) -> kaldi::CuVector<double>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use double for any matrix types in Kaldi for pybind.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but stats_sum_'s and stats_sumsq_ 's type are CuVector
and offset_ scale_ also give me this bug

print(component.Offset().to_dlpack())
<capsule object "dltensor" at 0x7f28404c1780>

print(from_dlpack(component.Offset().to_dlpack()))
RuntimeError: CUDA error: invalid argument

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not need stats_sum_ and stats_sumsq_ in inference mode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you invoked SetTestMode(true) before
print(from_dlpack(component.Offset().to_dlpack())) ?

Please re-check your code.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the StatsMean and StatsVar in kaldi's model are equivalent to running_mean, running_var in pytorch, respectively.
but in kaldi's Read and Write function, there are some other operator, like

Read
stats_sumsq_.AddVecVec(1.0, stats_sum_, stats_sum_, 1.0);
stats_sum_.Scale(count_);
stats_sumsq_.Scale(count_);
Write
CuVector<BaseFloat> mean(stats_sum_), var(stats_sumsq_);
  if (count_ != 0) {
    mean.Scale(1.0 / count_);
    var.Scale(1.0 / count_);
    var.AddVecVec(-1.0, mean, mean, 1.0);
  }

these operator are not necessary for pytorch's model, just use mean and var rather than offset and scale.

virtual const CuVector<double> &StatsSumsq() const { return stats_sumsq_; }
virtual const double &Count() const { return count_; }

private:

struct Memo {
Expand Down
5 changes: 5 additions & 0 deletions src/pybind/Makefile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \
nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_component_itf_pybind.cc \
nnet3/nnet_convolutional_component_pybind.cc \
nnet3/nnet_example_pybind.cc \
nnet3/nnet_nnet_pybind.cc \
nnet3/nnet_normalize_component_pybind.cc \
nnet3/nnet_simple_component_pybind.cc \
tests/test_dlpack_subvector.cc \
util/kaldi_holder_pybind.cc \
util/kaldi_io_pybind.cc \
Expand Down
7 changes: 5 additions & 2 deletions src/pybind/cudamatrix/cu_matrix_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ void pybind_cu_matrix(py::module& m) {
.def("Set", &PyClass::Set, py::arg("value"))
.def("Add", &PyClass::Add, py::arg("value"))
.def("Scale", &PyClass::Scale, py::arg("value"))
.def("to_dlpack",
[](py::object obj) { return CuMatrixToDLPack(&obj); })
.def("__getitem__",
[](const PyClass& m, std::pair<ssize_t, ssize_t> i) {
return m(i.first, i.second);
Expand All @@ -55,8 +57,9 @@ void pybind_cu_matrix(py::module& m) {
py::arg("MatrixStrideType") = kDefaultStride)
.def(py::init<const MatrixBase<float>&, MatrixTransposeType>(),
py::arg("other"), py::arg("trans") = kNoTrans)
.def("to_dlpack",
[](py::object obj) { return CuMatrixToDLPack(&obj); });
// .def("to_dlpack",
// [](py::object obj) { return CuMatrixToDLPack(&obj); })
;
}
{
using PyClass = CuSubMatrix<float>;
Expand Down
2 changes: 1 addition & 1 deletion src/pybind/dlpack/dlpack_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ py::capsule CuVectorToDLPack(py::object* obj) {
}

py::capsule CuMatrixToDLPack(py::object* obj) {
auto* m = obj->cast<CuMatrix<float>*>();
auto* m = obj->cast<CuMatrixBase<float>*>();
#if HAVE_CUDA == 1
KALDI_ASSERT(CuDevice::Instantiate().Enabled());

Expand Down
16 changes: 16 additions & 0 deletions src/pybind/kaldi/io_util.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
ki.Close()

return trans_model


def read_nnet3_model(filename):
'''Read nnet model from an filename.
'''
ki = kaldi_pybind.Input()
is_opened, is_binary = ki.Open(filename, read_header=True)
if not is_opened:
raise FileNotOpenException('Failed to open {}'.format(filename))

nnet = kaldi_pybind.nnet3.Nnet()
nnet.Read(ki.Stream(), is_binary)

ki.Close()

return nnet
10 changes: 10 additions & 0 deletions src/pybind/nnet3/nnet3_pybind.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,22 @@

#include "nnet3/nnet_chain_example_pybind.h"
#include "nnet3/nnet_common_pybind.h"
#include "nnet3/nnet_component_itf_pybind.h"
#include "nnet3/nnet_convolutional_component_pybind.h"
#include "nnet3/nnet_example_pybind.h"
#include "nnet3/nnet_nnet_pybind.h"
#include "nnet3/nnet_normalize_component_pybind.h"
#include "nnet3/nnet_simple_component_pybind.h"

void pybind_nnet3(py::module& _m) {
py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

pybind_nnet_common(m);
pybind_nnet_component_itf(m);
pybind_nnet_convolutional_component(m);
pybind_nnet_example(m);
pybind_nnet_chain_example(m);
pybind_nnet_nnet(m);
pybind_nnet_normalize_component(m);
pybind_nnet_simple_component(m);
}
40 changes: 40 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// pybind/nnet3/nnet_component_itf_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_component_itf_pybind.h"

#include "nnet3/nnet-component-itf.h"

using namespace kaldi::nnet3;

void pybind_nnet_component_itf(py::module& m) {
using PyClass = Component;
py::class_<PyClass>(m, "Component",
"Abstract base-class for neural-net components.")
.def("Type", &PyClass::Type,
"Returns a string such as \"SigmoidComponent\", describing the "
"type of the object.")
.def("Info", &PyClass::Info,
"Returns some text-form information about this component, for "
"diagnostics. Starts with the type of the component. E.g. "
"\"SigmoidComponent dim=900\", although most components will have "
"much more info.")
.def_static("NewComponentOfType", &PyClass::NewComponentOfType,
py::return_value_policy::take_ownership);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_component_itf_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_component_itf_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_component_itf(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
32 changes: 32 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// pybind/nnet3/nnet_convolutional_component_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_convolutional_component_pybind.h"

#include "nnet3/nnet-convolutional-component.h"

using namespace kaldi::nnet3;

void pybind_nnet_convolutional_component(py::module& m) {
using TC = kaldi::nnet3::TdnnComponent;
py::class_<TC>(m, "TdnnComponent")
.def("Type", &TC::Type)
.def("LinearParams", &TC::LinearParams, py::return_value_policy::reference)
.def("BiasParams", &TC::BiasParams, py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_convolutional_component_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_convolutional_component_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_convolutional_component(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
52 changes: 52 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// pybind/nnet3/nnet_nnet_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_nnet_pybind.h"

#include "nnet3/nnet-nnet.h"

using namespace kaldi;
using namespace kaldi::nnet3;

void pybind_nnet_nnet(py::module& m) {
using PyClass = kaldi::nnet3::Nnet;
auto nnet = py::class_<PyClass>(
m, "Nnet",
"This function can be used either to initialize a new Nnet from a "
"config file, or to add to an existing Nnet, possibly replacing "
"certain parts of it. It will die with error if something went wrong. "
"Also see the function ReadEditConfig() in nnet-utils.h (it's made a "
"non-member because it doesn't need special access).");
nnet.def(py::init<>())
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
.def("GetComponentNames", &PyClass::GetComponentNames,
"returns vector of component names (needed by some parsing code, "
"for instance).",
py::return_value_policy::reference)
.def("GetComponentName", &PyClass::GetComponentName,
py::arg("component_index"))
.def("Info", &PyClass::Info,
"returns some human-readable information about the network, "
"mostly for debugging purposes. Also see function NnetInfo() in "
"nnet-utils.h, which prints out more extensive infoformation.")
.def("NumComponents", &PyClass::NumComponents)
.def("NumNodes", &PyClass::NumNodes)
.def("GetComponent", (Component * (PyClass::*)(int32)) & PyClass::GetComponent,
py::arg("c"), py::return_value_policy::reference);
}
25 changes: 25 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/nnet3/nnet_nnet_pybind.h

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_nnet_nnet(py::module& m);

#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
40 changes: 40 additions & 0 deletions src/pybind/nnet3/nnet_nnet_pybind_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

# Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
# Apache 2.0

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

import unittest

import kaldi
from kaldi import read_nnet3_model
from torch.utils.dlpack import from_dlpack
from torch.utils.dlpack import to_dlpack

class TestNnetNnet(unittest.TestCase):

def test_nnet_nnet(self):
kaldi.SelectGpuId('yes')
final_mdl = "/mnt/cfs1_alias1/asr/users/fanlu/task/kaldi_recipe/pybind/s10.1/exp/chain_cleaned_1c/tdnn1c_sp/final.mdl"
nnet = kaldi.read_nnet3_model(final_mdl)
for i in range(nnet.NumComponents()):
component = nnet.GetComponent(i)
comp_type = component.Type()
if "Affine" in comp_type or "TdnnComponent" in comp_type:
linear_params = from_dlpack(component.LinearParams().to_dlpack())
bias_params = from_dlpack(component.BiasParams().to_dlpack())
print(linear_params.shape)
elif "Batch" in comp_type:
# stats_sum = from_dlpack(component.StatsSum().to_dlpack())
# stats_sumsq = from_dlpack(component.StatsSumsq().to_dlpack())
# print(stats_sum.shape)
pass
elif "LinearComponent" == comp_type:
linear_params = from_dlpack(component.LinearParams().to_dlpack())
print(linear_params.shape)

if __name__ == '__main__':
unittest.main()
35 changes: 35 additions & 0 deletions src/pybind/nnet3/nnet_normalize_component_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// pybind/nnet3/nnet_normalize_component_pybind.cc

// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet_normalize_component_pybind.h"

#include "nnet3/nnet-normalize-component.h"

using namespace kaldi::nnet3;

void pybind_nnet_normalize_component(py::module& m) {
using PyClass = kaldi::nnet3::BatchNormComponent;
py::class_<PyClass>(m, "BatchNormComponent")
.def("Type", &PyClass::Type)
.def("StatsSum", &PyClass::StatsSum)
.def("StatsSumsq", &PyClass::StatsSumsq)
.def("Count", &PyClass::Count)
.def("Offset", &PyClass::Offset)
.def("Scale", overload_cast_<>()(&PyClass::Scale, py::const_));
}
Loading