diff --git a/src/nnet3/nnet-convolutional-component.h b/src/nnet3/nnet-convolutional-component.h index 279cec321dd..bdb8daf75bb 100644 --- a/src/nnet3/nnet-convolutional-component.h +++ b/src/nnet3/nnet-convolutional-component.h @@ -553,6 +553,8 @@ class TdnnComponent: public UpdatableComponent { CuMatrixBase &LinearParams() { return linear_params_; } + const CuMatrix &Linearparams() const { return linear_params_; } + // This allows you to resize the vector in order to add a bias where // there previously was none-- obviously this should be done carefully. CuVector &BiasParams() { return bias_params_; } diff --git a/src/nnet3/nnet-normalize-component.cc b/src/nnet3/nnet-normalize-component.cc index fdfd9544785..845631b09bf 100644 --- a/src/nnet3/nnet-normalize-component.cc +++ b/src/nnet3/nnet-normalize-component.cc @@ -641,6 +641,24 @@ void BatchNormComponent::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, ""); } +CuVector BatchNormComponent::Mean() const { + CuVector mean(stats_sum_); + if (count_ != 0) { + mean.Scale(1.0 / count_); + } + return mean; +} + +CuVector BatchNormComponent::Var() const { + CuVector mean(stats_sum_), var(stats_sumsq_); + if (count_ != 0) { + mean.Scale(1.0 / count_); + var.Scale(1.0 / count_); + var.AddVecVec(-1.0, mean, mean, 1.0); + } + return var; +} + void BatchNormComponent::Scale(BaseFloat scale) { if (scale == 0) { count_ = 0.0; diff --git a/src/nnet3/nnet-normalize-component.h b/src/nnet3/nnet-normalize-component.h index 37ad624d0f0..4f92dd02bc6 100644 --- a/src/nnet3/nnet-normalize-component.h +++ b/src/nnet3/nnet-normalize-component.h @@ -224,6 +224,11 @@ class BatchNormComponent: public Component { const CuVector &Offset() const { return offset_; } const CuVector &Scale() const { return scale_; } + CuVector Mean() const; + CuVector Var() const; + double Count() const { return count_; } + BaseFloat Eps() const { return epsilon_; } + private: struct Memo { diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 546176f71ee..12514298961 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -971,6 +971,7 @@ class LinearComponent: public UpdatableComponent { BaseFloat OrthonormalConstraint() const { return orthonormal_constraint_; } CuMatrixBase &Params() { return params_; } const CuMatrixBase &Params() const { return params_; } + const CuMatrix &Params2() const { return params_; } private: // disallow assignment operator. diff --git a/src/pybind/Makefile b/src/pybind/Makefile old mode 100644 new mode 100755 index 7e3079b7d09..a1e8252c982 --- a/src/pybind/Makefile +++ b/src/pybind/Makefile @@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \ nnet3/nnet3_pybind.cc \ nnet3/nnet_chain_example_pybind.cc \ nnet3/nnet_common_pybind.cc \ +nnet3/nnet_component_itf_pybind.cc \ +nnet3/nnet_convolutional_component_pybind.cc \ nnet3/nnet_example_pybind.cc \ +nnet3/nnet_nnet_pybind.cc \ +nnet3/nnet_normalize_component_pybind.cc \ +nnet3/nnet_simple_component_pybind.cc \ tests/test_dlpack_subvector.cc \ util/kaldi_holder_pybind.cc \ util/kaldi_io_pybind.cc \ diff --git a/src/pybind/kaldi/io_util.py b/src/pybind/kaldi/io_util.py old mode 100644 new mode 100755 index f26229a8661..158fbf19d26 --- a/src/pybind/kaldi/io_util.py +++ b/src/pybind/kaldi/io_util.py @@ -82,3 +82,19 @@ def read_transition_model(rxfilename): ki.Close() return trans_model + + +def read_nnet3_model(rxfilename): + '''Read nnet model from an rxfilename. + ''' + ki = kaldi_pybind.Input() + is_opened, is_binary = ki.Open(rxfilename, read_header=True) + if not is_opened: + raise FileNotOpenException('Failed to open {}'.format(rxfilename)) + + nnet = kaldi_pybind.nnet3.Nnet() + nnet.Read(ki.Stream(), is_binary) + + ki.Close() + + return nnet diff --git a/src/pybind/nnet3/Makefile b/src/pybind/nnet3/Makefile index d9fd9059847..6f023432a13 100644 --- a/src/pybind/nnet3/Makefile +++ b/src/pybind/nnet3/Makefile @@ -1,4 +1,5 @@ test: python3 ./nnet_chain_example_pybind_test.py + python3 ./nnet_nnet_pybind_test.py diff --git a/src/pybind/nnet3/final.mdl b/src/pybind/nnet3/final.mdl new file mode 100644 index 00000000000..8b5613d7fa2 Binary files /dev/null and b/src/pybind/nnet3/final.mdl differ diff --git a/src/pybind/nnet3/nnet3_pybind.cc b/src/pybind/nnet3/nnet3_pybind.cc index c0ccd5979cb..3397ee87144 100644 --- a/src/pybind/nnet3/nnet3_pybind.cc +++ b/src/pybind/nnet3/nnet3_pybind.cc @@ -2,6 +2,7 @@ // Copyright 2019 Mobvoi AI Lab, Beijing, China // (author: Fangjun Kuang, Yaguang Hu, Jian Wang) +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) // See ../../../COPYING for clarification regarding multiple authors // @@ -22,12 +23,22 @@ #include "nnet3/nnet_chain_example_pybind.h" #include "nnet3/nnet_common_pybind.h" +#include "nnet3/nnet_component_itf_pybind.h" +#include "nnet3/nnet_convolutional_component_pybind.h" #include "nnet3/nnet_example_pybind.h" +#include "nnet3/nnet_nnet_pybind.h" +#include "nnet3/nnet_normalize_component_pybind.h" +#include "nnet3/nnet_simple_component_pybind.h" void pybind_nnet3(py::module& _m) { py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi"); pybind_nnet_common(m); + pybind_nnet_component_itf(m); + pybind_nnet_convolutional_component(m); pybind_nnet_example(m); pybind_nnet_chain_example(m); + pybind_nnet_nnet(m); + pybind_nnet_normalize_component(m); + pybind_nnet_simple_component(m); } diff --git a/src/pybind/nnet3/nnet_chain_example_pybind_test.py b/src/pybind/nnet3/nnet_chain_example_pybind_test.py old mode 100755 new mode 100644 diff --git a/src/pybind/nnet3/nnet_component_itf_pybind.cc b/src/pybind/nnet3/nnet_component_itf_pybind.cc new file mode 100644 index 00000000000..66011d3af26 --- /dev/null +++ b/src/pybind/nnet3/nnet_component_itf_pybind.cc @@ -0,0 +1,40 @@ +// pybind/nnet3/nnet_component_itf_pybind.cc + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_component_itf_pybind.h" + +#include "nnet3/nnet-component-itf.h" + +using namespace kaldi::nnet3; + +void pybind_nnet_component_itf(py::module& m) { + using PyClass = Component; + py::class_(m, "Component", + "Abstract base-class for neural-net components.") + .def("Type", &PyClass::Type, + "Returns a string such as \"SigmoidComponent\", describing the " + "type of the object.") + .def("Info", &PyClass::Info, + "Returns some text-form information about this component, for " + "diagnostics. Starts with the type of the component. E.g. " + "\"SigmoidComponent dim=900\", although most components will have " + "much more info.") + .def_static("NewComponentOfType", &PyClass::NewComponentOfType, + py::return_value_policy::take_ownership); +} diff --git a/src/pybind/nnet3/nnet_component_itf_pybind.h b/src/pybind/nnet3/nnet_component_itf_pybind.h new file mode 100644 index 00000000000..662d372ae37 --- /dev/null +++ b/src/pybind/nnet3/nnet_component_itf_pybind.h @@ -0,0 +1,25 @@ +// pybind/nnet3/nnet_component_itf_pybind.h + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_component_itf(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_convolutional_component_pybind.cc b/src/pybind/nnet3/nnet_convolutional_component_pybind.cc new file mode 100644 index 00000000000..560c98257f6 --- /dev/null +++ b/src/pybind/nnet3/nnet_convolutional_component_pybind.cc @@ -0,0 +1,33 @@ +// pybind/nnet3/nnet_convolutional_component_pybind.cc + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_convolutional_component_pybind.h" + +#include "nnet3/nnet-convolutional-component.h" + +using namespace kaldi::nnet3; + +void pybind_nnet_convolutional_component(py::module& m) { + using TC = kaldi::nnet3::TdnnComponent; + py::class_(m, "TdnnComponent") + .def("LinearParams", &TC::Linearparams, + py::return_value_policy::reference) + .def("BiasParams", &TC::BiasParams, + py::return_value_policy::reference); +} diff --git a/src/pybind/nnet3/nnet_convolutional_component_pybind.h b/src/pybind/nnet3/nnet_convolutional_component_pybind.h new file mode 100644 index 00000000000..8903884ad2d --- /dev/null +++ b/src/pybind/nnet3/nnet_convolutional_component_pybind.h @@ -0,0 +1,25 @@ +// pybind/nnet3/nnet_convolutional_component_pybind.h + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_convolutional_component(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_nnet_pybind.cc b/src/pybind/nnet3/nnet_nnet_pybind.cc new file mode 100644 index 00000000000..8995d0cb100 --- /dev/null +++ b/src/pybind/nnet3/nnet_nnet_pybind.cc @@ -0,0 +1,52 @@ +// pybind/nnet3/nnet_nnet_pybind.cc + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_nnet_pybind.h" + +#include "nnet3/nnet-nnet.h" + +using namespace kaldi; +using namespace kaldi::nnet3; + +void pybind_nnet_nnet(py::module& m) { + using PyClass = kaldi::nnet3::Nnet; + auto nnet = py::class_( + m, "Nnet", + "This function can be used either to initialize a new Nnet from a " + "config file, or to add to an existing Nnet, possibly replacing " + "certain parts of it. It will die with error if something went wrong. " + "Also see the function ReadEditConfig() in nnet-utils.h (it's made a " + "non-member because it doesn't need special access)."); + nnet.def(py::init<>()) + .def("Read", &PyClass::Read, py::arg("is"), py::arg("binary")) + .def("GetComponentNames", &PyClass::GetComponentNames, + "returns vector of component names (needed by some parsing code, " + "for instance).", + py::return_value_policy::reference) + .def("GetComponentName", &PyClass::GetComponentName, + py::arg("component_index")) + .def("Info", &PyClass::Info, + "returns some human-readable information about the network, " + "mostly for debugging purposes. Also see function NnetInfo() in " + "nnet-utils.h, which prints out more extensive infoformation.") + .def("NumComponents", &PyClass::NumComponents) + .def("NumNodes", &PyClass::NumNodes) + .def("GetComponent", (Component * (PyClass::*)(int32)) & PyClass::GetComponent, + py::arg("c"), py::return_value_policy::reference); +} diff --git a/src/pybind/nnet3/nnet_nnet_pybind.h b/src/pybind/nnet3/nnet_nnet_pybind.h new file mode 100644 index 00000000000..6dcc45d8417 --- /dev/null +++ b/src/pybind/nnet3/nnet_nnet_pybind.h @@ -0,0 +1,25 @@ +// pybind/nnet3/nnet_nnet_pybind.h + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_nnet(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_nnet_pybind_test.py b/src/pybind/nnet3/nnet_nnet_pybind_test.py new file mode 100644 index 00000000000..5efa06058fa --- /dev/null +++ b/src/pybind/nnet3/nnet_nnet_pybind_test.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# Copyright 2020 JD AI, Beijing, China (author: Lu Fan) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +try: + import torch + from torch.utils.dlpack import to_dlpack + from torch.utils.dlpack import from_dlpack +except ImportError: + print('This test needs PyTorch.') + print('Please install PyTorch first.') + print('PyTorch 1.3.0dev20191006 has been tested and is known to work.') + sys.exit(0) + +import kaldi + +""" +input dim=40 name=input + +# please note that it is important to have input layer with the name=input +# as the layer immediately preceding the fixed-affine-layer to enable +# the use of short notation for the descriptor +fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat + +# the first splicing is moved before the lda layer, so no splicing here +relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=16 +tdnnf-layer name=tdnnf2 $tdnnf_opts dim=16 bottleneck-dim=2 time-stride=1 +linear-component name=prefinal-l dim=4 $linear_opts + +prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=16 small-dim=4 +output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + +prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=16 small-dim=4 +output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +""" +class TestNnetNnet(unittest.TestCase): + + def test_nnet_nnet(self): + if torch.cuda.is_available() == False: + print('No GPU detected! Skip it') + return + + if kaldi.CudaCompiled() == False: + print('Kaldi is not compiled with CUDA! Skip it') + return + + device_id = 0 + + # Kaldi and PyTorch will use the same GPU + kaldi.SelectGpuDevice(device_id=device_id) + kaldi.CuDeviceAllowMultithreading() + + final_mdl = 'final.mdl' + nnet = kaldi.read_nnet3_model(final_mdl) + for i in range(nnet.NumComponents()): + component = nnet.GetComponent(i) + comp_type = component.Type() + if comp_type in ['RectifiedLinearComponent', 'GeneralDropoutComponent', + 'NoOpComponent']: + continue + comp_name = nnet.GetComponentName(i) + if comp_name == 'lda': + self.assertEqual(comp_type, 'FixedAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (120, 120)) + self.assertEqual(bias_params.shape, (120,)) + elif comp_name == 'tdnn1.affine': + self.assertEqual(comp_type, 'NaturalGradientAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 120)) + self.assertEqual(bias_params.shape, (16,)) + elif comp_name == 'tdnn1.batchnorm': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (16,)) + self.assertEqual(var.shape, (16,)) + elif comp_name == 'tdnnf2.linear': + self.assertEqual(comp_type, 'TdnnComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + self.assertEqual(linear_params.shape, (2, 32)) + elif comp_name == 'tdnnf2.affine': + self.assertEqual(comp_type, 'TdnnComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 4)) + self.assertEqual(bias_params.shape, (16,)) + elif comp_name == 'tdnnf2.batchnorm': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (16,)) + self.assertEqual(var.shape, (16,)) + elif comp_name == 'prefinal-l': + self.assertEqual(comp_type, 'LinearComponent') + params = from_dlpack(component.Params().to_dlpack()) + self.assertEqual(params.shape, (4, 16)) + elif comp_name == 'prefinal-chain.affine': + self.assertEqual(comp_type, 'NaturalGradientAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 4)) + self.assertEqual(bias_params.shape, (16,)) + elif comp_name == 'prefinal-chain.batchnorm1': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (16,)) + self.assertEqual(var.shape, (16,)) + elif comp_name == 'prefinal-chain.linear': + self.assertEqual(comp_type, 'LinearComponent') + params = from_dlpack(component.Params().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 4)) + elif comp_name == 'prefinal-chain.batchnorm2': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (4,)) + self.assertEqual(var.shape, (4,)) + elif comp_name == 'output.affine': + self.assertEqual(comp_type, 'NaturalGradientAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (3448, 4)) + self.assertEqual(bias_params.shape, (3448,)) + elif comp_name == 'prefinal-xent.affine': + self.assertEqual(comp_type, 'NaturalGradientAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 4)) + self.assertEqual(bias_params.shape, (16,)) + elif comp_name == 'prefinal-xent.batchnorm1': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (16,)) + self.assertEqual(var.shape, (16,)) + elif comp_name == 'prefinal-xent.linear': + self.assertEqual(comp_type, 'LinearComponent') + params = from_dlpack(component.Params().to_dlpack()) + self.assertEqual(linear_params.shape, (16, 4)) + elif comp_name == 'prefinal-xent.batchnorm2': + self.assertEqual(comp_type, 'BatchNormComponent') + component.SetTestMode(True) + mean = from_dlpack(component.Mean().to_dlpack()) + var = from_dlpack(component.Var().to_dlpack()) + self.assertEqual(mean.shape, (4,)) + self.assertEqual(var.shape, (4,)) + elif comp_name == 'output-xent.affine': + self.assertEqual(comp_type, 'NaturalGradientAffineComponent') + linear_params = from_dlpack( + component.LinearParams().to_dlpack()) + bias_params = from_dlpack(component.BiasParams().to_dlpack()) + self.assertEqual(linear_params.shape, (3448, 4)) + self.assertEqual(bias_params.shape, (3448,)) + else: + self.assertEqual(comp_type, 'LogSoftmaxComponent') + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/nnet3/nnet_normalize_component_pybind.cc b/src/pybind/nnet3/nnet_normalize_component_pybind.cc new file mode 100644 index 00000000000..ab5b4b79f3c --- /dev/null +++ b/src/pybind/nnet3/nnet_normalize_component_pybind.cc @@ -0,0 +1,37 @@ +// pybind/nnet3/nnet_normalize_component_pybind.cc + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_normalize_component_pybind.h" + +#include "nnet3/nnet-normalize-component.h" + +using namespace kaldi::nnet3; + +void pybind_nnet_normalize_component(py::module& m) { + using PyClass = kaldi::nnet3::BatchNormComponent; + py::class_(m, "BatchNormComponent") + .def("Mean", &PyClass::Mean) + .def("Var", &PyClass::Var) + .def("Count", &PyClass::Count) + .def("Eps", &PyClass::Eps) + .def("SetTestMode", &PyClass::SetTestMode, py::arg("test_mode")) + .def("Offset", &PyClass::Offset, py::return_value_policy::reference) + .def("Scale", overload_cast_<>()(&PyClass::Scale, py::const_), + py::return_value_policy::reference); +} diff --git a/src/pybind/nnet3/nnet_normalize_component_pybind.h b/src/pybind/nnet3/nnet_normalize_component_pybind.h new file mode 100644 index 00000000000..fd057cf6763 --- /dev/null +++ b/src/pybind/nnet3/nnet_normalize_component_pybind.h @@ -0,0 +1,25 @@ +// pybind/nnet3/nnet_normalize_component_pybind.h + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_NNET3_NNET_NORMALIZE_COMPONENT_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_NORMALIZE_COMPONENT_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_normalize_component(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_NORMALIZE_COMPONENT_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_simple_component_pybind.cc b/src/pybind/nnet3/nnet_simple_component_pybind.cc new file mode 100644 index 00000000000..1a32bef3496 --- /dev/null +++ b/src/pybind/nnet3/nnet_simple_component_pybind.cc @@ -0,0 +1,47 @@ +// pybind/nnet3/nnet_simple_component_pybind.cc + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_simple_component_pybind.h" + +#include "nnet3/nnet-simple-component.h" + +using namespace kaldi::nnet3; + +void pybind_nnet_simple_component(py::module& m) { + using FAC = FixedAffineComponent; + py::class_(m, "FixedAffineComponent") + .def("LinearParams", &FAC::LinearParams, + py::return_value_policy::reference) + .def("BiasParams", &FAC::BiasParams, py::return_value_policy::reference); + + using LC = LinearComponent; + py::class_(m, "LinearComponent") + .def("Params", overload_cast_<>()(&LC::Params2, py::const_), + py::return_value_policy::reference); + + using AC = AffineComponent; + py::class_(m, "AffineComponent") + .def("LinearParams", overload_cast_<>()(&AC::LinearParams, py::const_), + py::return_value_policy::reference) + .def("BiasParams", overload_cast_<>()(&AC::BiasParams, py::const_), + py::return_value_policy::reference); + + using NGAC = NaturalGradientAffineComponent; + py::class_(m, "NaturalGradientAffineComponent"); +} diff --git a/src/pybind/nnet3/nnet_simple_component_pybind.h b/src/pybind/nnet3/nnet_simple_component_pybind.h new file mode 100644 index 00000000000..e11e4ed4c10 --- /dev/null +++ b/src/pybind/nnet3/nnet_simple_component_pybind.h @@ -0,0 +1,25 @@ +// pybind/nnet3/nnet_simple_component_pybind.h + +// Copyright 2020 JD AI, Beijing, China (author: Lu Fan) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_NNET3_NNET_SIMPLE_COMPONENT_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_SIMPLE_COMPONENT_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_simple_component(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_SIMPLE_COMPONENT_PYBIND_H_