@fanlu

@fanlu fanlu commented Mar 4, 2020

@csukuangfj I have tested all the calculations between Kaldi's tdnn_1c model and the PyTorch kernel(2,2) model; the earlier issue was discussed in PR #3940.
After loading the Kaldi model into PyTorch, we now get similar results.

| exp | kaldi test cer | test wer | dev cer | dev wer |
| --- | --- | --- | --- | --- |
| tdnn_1c | 6.72 | 15.14 | 5.73 | 13.53 |
| tdnn_1c load in pytorch, BN eps: 1e-3 (default kaldi config) | 6.76 | 15.20 | 5.75 | 13.54 |
| tdnn_1c load in pytorch, BN eps: 1e-5 (default pytorch config) | 6.86 | 15.39 | 5.85 | 13.68 |
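
The gap between the two PyTorch rows is consistent with how epsilon enters batch normalization: it is added to the variance before the square root, so a different value rescales every normalized activation slightly. The sketch below is only an illustration in pure Python, assuming the usual formula `(x - mean) / sqrt(var + eps)` with no extra scaling; the statistics are made up and are not taken from the tdnn_1c model.

```python
import math


def batch_norm(x, mean, var, eps):
    """Normalize one value with stored statistics: (x - mean) / sqrt(var + eps)."""
    return (x - mean) / math.sqrt(var + eps)


# Made-up statistics for illustration only; not taken from the tdnn_1c model.
x, mean, var = 0.5, 0.1, 0.01

y_kaldi = batch_norm(x, mean, var, eps=1e-3)    # eps from the Kaldi config row above
y_pytorch = batch_norm(x, mean, var, eps=1e-5)  # eps from the PyTorch default row above

print(y_kaldi, y_pytorch)
```

The smaller the stored variance, the larger the effect of the mismatch. In PyTorch, `eps` is a constructor argument of `torch.nn.BatchNorm1d`, so it can be set to match the Kaldi value when the layer is created.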

At first I used a tool to load the model, but it is too slow.
So I want to implement a pybind function to load Kaldi's model, which is also more intuitive. What do you think?
But I have a problem with GetComponent right now; sorry, I am not good at pybind11 and C++.
If you have time, please help me or give me some hints. Thanks.


Thanks to @csukuangfj's help, it works now.

@danpovey
Contributor

danpovey commented Mar 4, 2020

What is the problem with GetComponent()?
You might want to start with existing model-evaluation code, such as in nnet3-compute.

@fanlu
Author

fanlu commented Mar 4, 2020

Hi Dan, I have tested nnet3-compute.cc thoroughly.
The problem with GetComponent is that it returns a pointer to Component.
I don't know how to wrap that in pybind.
After adding .def("GetComponent", &PyClass::GetComponent, py::arg("c")) to nnet_nnet_pybind.cc,
I got an error message during make. Maybe I need to read the pybind11 documentation about this.

template argument deduction/substitution failed:
nnet3/nnet_nnet_pybind.cc:47:63: note:   mismatched types 'pybind11::detail::initimpl::pickle_factory<Args ...>' and 'const char [13]'
      .def("GetComponent", &PyClass::GetComponent, py::arg("c"))

@csukuangfj
Contributor

That is because GetComponent is an overloaded method.
You have to specify which one you want to wrap.

@fanlu
Author

fanlu commented Mar 4, 2020

Should I wrap all of the Components below? Is there an easier way?

// static
Component* Component::NewComponentOfType(const std::string &component_type) {
  Component *ans = NULL;
  if (component_type == "SigmoidComponent") {
    ans = new SigmoidComponent();
  } else if (component_type == "TanhComponent") {
    ans = new TanhComponent();
  } else if (component_type == "SoftmaxComponent") {
    ans = new SoftmaxComponent();
  } else if (component_type == "LogSoftmaxComponent") {
    ans = new LogSoftmaxComponent();
  } else if (component_type == "RectifiedLinearComponent") {
    ans = new RectifiedLinearComponent();
  } else if (component_type == "NormalizeComponent") {
    ans = new NormalizeComponent();
  } else if (component_type == "PnormComponent") {
    ans = new PnormComponent();
  } else if (component_type == "AffineComponent") {
    ans = new AffineComponent();
  } else if (component_type == "LinearComponent") {
    ans = new LinearComponent();
  } else if (component_type == "NaturalGradientAffineComponent") {
    ans = new NaturalGradientAffineComponent();
  } else if (component_type == "PerElementScaleComponent") {
    ans = new PerElementScaleComponent();
  } else if (component_type == "NaturalGradientPerElementScaleComponent") {
    ans = new NaturalGradientPerElementScaleComponent();
  } else if (component_type == "PerElementOffsetComponent") {
    ans = new PerElementOffsetComponent();
  } else if (component_type == "SumGroupComponent") {
    ans = new SumGroupComponent();
  } else if (component_type == "FixedAffineComponent") {
    ans = new FixedAffineComponent();
  } else if (component_type == "FixedScaleComponent") {
    ans = new FixedScaleComponent();
  } else if (component_type == "FixedBiasComponent") {
    ans = new FixedBiasComponent();
  } else if (component_type == "NoOpComponent") {
    ans = new NoOpComponent();
  } else if (component_type == "ClipGradientComponent") {
    ans = new ClipGradientComponent();
  } else if (component_type == "ElementwiseProductComponent") {
    ans = new ElementwiseProductComponent();
  } else if (component_type == "ConvolutionComponent") {
    ans = new ConvolutionComponent();
  } else if (component_type == "TdnnComponent") {
    ans = new TdnnComponent();
  } else if (component_type == "MaxpoolingComponent") {
    ans = new MaxpoolingComponent();
  } else if (component_type == "PermuteComponent") {
    ans = new PermuteComponent();
  } else if (component_type == "DistributeComponent") {
    ans = new DistributeComponent();
  } else if (component_type == "CompositeComponent") {
    ans = new CompositeComponent();
  } else if (component_type == "RepeatedAffineComponent") {
    ans = new RepeatedAffineComponent();
  } else if (component_type == "BlockAffineComponent") {
    ans = new BlockAffineComponent();
  } else if (component_type == "NaturalGradientRepeatedAffineComponent") {
    ans = new NaturalGradientRepeatedAffineComponent();
  } else if (component_type == "StatisticsExtractionComponent") {
    ans = new StatisticsExtractionComponent();
  } else if (component_type == "StatisticsPoolingComponent") {
    ans = new StatisticsPoolingComponent();
  } else if (component_type == "ConstantFunctionComponent") {
    ans = new ConstantFunctionComponent();
  } else if (component_type == "ConstantComponent") {
    ans = new ConstantComponent();
  } else if (component_type == "DropoutComponent") {
    ans = new DropoutComponent();
  } else if (component_type == "DropoutMaskComponent") {
    ans = new DropoutMaskComponent();
  } else if (component_type == "GeneralDropoutComponent") {
    ans = new GeneralDropoutComponent();
  } else if (component_type == "SpecAugmentTimeMaskComponent") {
    ans = new SpecAugmentTimeMaskComponent();
  } else if (component_type == "BackpropTruncationComponent") {
    ans = new BackpropTruncationComponent();
  } else if (component_type == "LstmNonlinearityComponent") {
    ans = new LstmNonlinearityComponent();
  } else if (component_type == "BatchNormComponent") {
    ans = new BatchNormComponent();
  } else if (component_type == "TimeHeightConvolutionComponent") {
    ans = new TimeHeightConvolutionComponent();
  } else if (component_type == "RestrictedAttentionComponent") {
    ans = new RestrictedAttentionComponent();
  } else if (component_type == "SumBlockComponent") {
    ans = new SumBlockComponent();
  } else if (component_type == "GruNonlinearityComponent") {
    ans = new GruNonlinearityComponent();
  } else if (component_type == "OutputGruNonlinearityComponent") {
    ans = new OutputGruNonlinearityComponent();
  } else if (component_type == "ScaleAndOffsetComponent") {
    ans = new ScaleAndOffsetComponent();
  }
  if (ans != NULL) {
    KALDI_ASSERT(component_type == ans->Type());
  }
  return ans;
}

@csukuangfj
Contributor

csukuangfj commented Mar 4, 2020

Please refer to

(PyClass(*)(const PyClass&, const std::vector<std::vector<FloatType>>&))(

for how to wrap overloaded methods.

@csukuangfj
Contributor

csukuangfj commented Mar 4, 2020

It is needed if you want to build a nnet3 model from a config file.

Note that it is a static method.

@csukuangfj
Contributor

Please refer to

.def_static("Zero", &PyClass::Zero)

for how to wrap static methods.

Note that in your case, the function returns a pointer. You have to use the flag
take_ownership.

@danpovey
Contributor

danpovey commented Mar 4, 2020 via email

@csukuangfj
Contributor

.def_static("NewComponentOfType", &PyClass::NewComponentOfType, py::return_value_policy::take_ownership)

will do the job.

You do NOT need to wrap other components to wrap this function.

@csukuangfj
Contributor

If you have time, please write a nnet_nnet_pybind_test.py
to test your code.

@fanlu
Author

fanlu commented Mar 4, 2020

Should I wrap the code that initializes an nnet3 model from a config file? I only want an nnet3 model read from a binary model file.
Following your suggestion, I wrapped GetComponent like this:

.def("GetComponent", (kaldi::nnet3::Component(*)(int32 c))(&PyClass::GetComponent), py::arg("c"), py::return_value_policy::take_ownership)

and this gives me an error message:

nnet3/nnet_nnet_pybind.cc:48:87: error: address of overloaded function with no contextual type information
      .def("GetComponent", (kaldi::nnet3::Component(*)(int32 c))(&PyClass::GetComponent), py::arg("c"), py::return_value_policy::take_ownership)

@csukuangfj
Contributor

csukuangfj commented Mar 4, 2020

.def("GetComponent", (kaldi::nnet3::Component(*)(int32 c))(&PyClass::GetComponent), py::arg("c"), py::return_value_policy::take_ownership)

what is the function type that you're trying to convert to?

@csukuangfj
Contributor

(ret_type (class_name::*)(arg_list) [optional_const_specifier])

The example I gave is for a global function. In your case, it is a class member function.

@csukuangfj
Contributor

&PyClass::GetComponent returns a pointer.
(type_converted_to) (&PyClass::GetComponent), you have to figure out the appropriate type_converted_to

@fanlu
Author

fanlu commented Mar 4, 2020

Note what kaldi-onnx does. I want to get a dict from Kaldi's model file, like the code below.

kaldi_dict = {}
with open(filename, 'r') as f:
    p = Nnet3Parser(f)
    p.run()
    for component in p._components:
        kaldi_dict[component["name"]] = component

So I suppose I can loop over range(NumComponents()), call GetComponent on each index, and read each child Component's parameters.
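
The loop described here could look roughly like the sketch below. It assumes the Python bindings expose NumComponents(), GetComponent(c) and GetComponentName(c) under the same names as the C++ Nnet class, which is an assumption about the bindings rather than something confirmed in this thread. FakeNnet and FakeComponent are hypothetical stand-ins so the example runs without a compiled module, and the component names are invented.

```python
class FakeComponent:
    """Hypothetical stand-in for a wrapped kaldi::nnet3::Component."""

    def __init__(self, type_name):
        self._type = type_name

    def Type(self):
        return self._type


class FakeNnet:
    """Hypothetical stand-in for a wrapped kaldi::nnet3::Nnet."""

    def __init__(self, named_components):
        self._names = [name for name, _ in named_components]
        self._components = [comp for _, comp in named_components]

    def NumComponents(self):
        return len(self._components)

    def GetComponentName(self, c):
        return self._names[c]

    def GetComponent(self, c):
        return self._components[c]


def components_by_name(nnet):
    """Map each component name to its component object, in index order."""
    return {
        nnet.GetComponentName(c): nnet.GetComponent(c)
        for c in range(nnet.NumComponents())
    }


# Invented component names, for illustration only.
nnet = FakeNnet([
    ("tdnn1.affine", FakeComponent("TdnnComponent")),
    ("tdnn1.batchnorm", FakeComponent("BatchNormComponent")),
])
kaldi_dict = components_by_name(nnet)
```

With a real binding, only the construction of `nnet` would change; `components_by_name` relies on nothing but the three accessor calls.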

@fanlu
Author

fanlu commented Mar 4, 2020

Can type_converted_to be Component?

&PyClass::GetComponent returns a pointer.
(type_converted_to) (&PyClass::GetComponent), you have to figure out the appropriate type_converted_to

@csukuangfj
Contributor

As you can see from the following

kaldi/src/nnet3/nnet-nnet.h

Lines 128 to 133 in cc63cca

/// Return component indexed c. Not a copy; not owned by caller.
Component *GetComponent(int32 c);
/// Return component indexed c (const version). Not a copy; not owned by
/// caller.
const Component *GetComponent(int32 c) const;

GetComponent is an overloaded method. You have to specify explicitly which method
is going to be wrapped; otherwise, there is no way for pybind11 to infer the prototype of
GetComponent.

You can use either

.def("GetComponent", (Component* (PyClass::*)(int32))&PyClass::GetComponent, py::arg("c"), py::return_value_policy::reference)

or

.def("GetComponent", (const Component* (PyClass::*)(int32) const)&PyClass::GetComponent, py::arg("c"), py::return_value_policy::reference)

Or you can use both.

Since the return type is Component*, you have to wrap it first.

class Component {

@fanlu
Author

fanlu commented Mar 4, 2020

I now have the Component object. Following Dan's advice, is there a convenient way to cast a Component to its child class if I do not wrap all of the child classes?

FixedAffineComponent

  const CuMatrix<BaseFloat> &LinearParams() const { return linear_params_; }
  const CuVector<BaseFloat> &BiasParams() const { return bias_params_; }

@csukuangfj
Contributor

@fanlu
You can wrap any subclasses of Component that you are using.
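
On the Python side, wrapping only some subclasses suggests dispatching on the component type and skipping everything else. The sketch below is a rough illustration under two assumptions: that the wrapped objects expose Type(), and that the accessors keep the C++ names quoted in this thread (LinearParams, BiasParams, Params). The Fake* classes are hypothetical stand-ins and return plain lists instead of Kaldi matrices.

```python
# Accessor names per component type, taken from the C++ snippets quoted in this thread.
PARAM_ACCESSORS = {
    "FixedAffineComponent": ("LinearParams", "BiasParams"),
    "TdnnComponent": ("LinearParams", "BiasParams"),
    "LinearComponent": ("Params",),
}


def extract_params(component):
    """Return {accessor_name: value} for known types, or {} for anything else."""
    accessors = PARAM_ACCESSORS.get(component.Type(), ())
    return {name: getattr(component, name)() for name in accessors}


class FakeLinearComponent:
    """Hypothetical stand-in for a wrapped LinearComponent."""

    def Type(self):
        return "LinearComponent"

    def Params(self):
        return [[1.0, 2.0], [3.0, 4.0]]


class FakeNoOpComponent:
    """Hypothetical stand-in for a component type with no wrapped accessors."""

    def Type(self):
        return "NoOpComponent"


params = extract_params(FakeLinearComponent())
skipped = extract_params(FakeNoOpComponent())
```

Keeping the accessor names in one table means a newly wrapped subclass needs one new entry and no change to the function.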

@francisr
Contributor

francisr commented Mar 4, 2020

At first I used a tool to load the model, but it is too slow.
What is slow about the tool? The conversion to ONNX or the runtime?

@fanlu
Author

fanlu commented Mar 4, 2020

Nnet3Parser(f).run()

At first I used a tool to load the model, but it is too slow.
What is slow about the tool? The conversion to ONNX or the runtime?

@csukuangfj
Contributor

Is the conversion process slow?
Or after the conversion, is the runtime slow?

@csukuangfj
Contributor

You can split Nnet3Parser(f).run() into

start = start_time
model = Nnet3Parser(f)
stop = stop_time
--> time for the conversion process

start = start_time
model.run()
stop = stop_time
--> time for the runtime
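
One way to take these two measurements is a small helper around time.perf_counter(), as in the sketch below. SlowParser is a hypothetical stand-in for Nnet3Parser (a cheap constructor and a costly run()), used only so the example runs on its own; the sleep is a placeholder for the real parsing work.

```python
import time


class SlowParser:
    """Hypothetical stand-in for Nnet3Parser: cheap constructor, costly run()."""

    def __init__(self, f):
        self._f = f

    def run(self):
        time.sleep(0.05)  # placeholder for the real parsing work


def timed(fn, *args):
    """Call fn(*args) once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


parser, construct_seconds = timed(SlowParser, None)
_, run_seconds = timed(parser.run)
print(f"construct: {construct_seconds:.6f}s  run: {run_seconds:.6f}s")
```

time.perf_counter() is preferred over time.time() for short intervals because it is monotonic and uses the highest-resolution clock available.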

@fanlu
Author

fanlu commented Mar 5, 2020

The conversion and the runtime take 0.0004012584686279297 s and 97.50831198692322 s, respectively.

@csukuangfj
Contributor

I have a problem with from_dlpack() now.

Please refer to the above comment.

@csukuangfj
Contributor

Please mark a comment as resolved when it is done.

@fanlu
Author

fanlu commented Mar 6, 2020

Because in class TdnnComponent the accessor for linear_params_ returns CuMatrixBase:

CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }

Should I implement a new method like this?

CuMatrix<BaseFloat> &LinearParams() { return linear_params_; }

@fanlu

I have also moved to_dlpack() from CuMatrix to CuMatrixBase; I am not sure this is the right way.

Please refer to Dan's comment (#3788 (comment)):

I think it might be worth having a to_dlpack specifically in class Matrix, rather than in MatrixBase. Class Matrix owns its memory, so if the deleter were to do a Py_DECREF on the Python object where that Matrix lives, the memory would be managed appropriately. I'm concerned that with the current approach, it's just a segmentation fault waiting to happen because it doesn't automatically ensure that the pointer is still valid/allocated.

@csukuangfj
Contributor

@fanlu
The underlying type is

CuMatrix<BaseFloat> linear_params_;

Please show me the return type in Python using print() and type().
I guess pybind11 is smart enough to infer the return type and will NOT give you an object of type CuMatrixBase.

@fanlu
Author

fanlu commented Mar 6, 2020

print(component.LinearParams())
<kaldi_pybind.FloatCuMatrixBase object at 0x7f2890160ef0>

print(type(component.LinearParams()))
<class 'kaldi_pybind.FloatCuMatrixBase'>

@csukuangfj
Contributor

I think it is a mistake in the C++ code to use CuMatrixBase here:

CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }
// This allows you to resize the vector in order to add a bias where
// there previously was none-- obviously this should be done carefully.
CuVector<BaseFloat> &BiasParams() { return bias_params_; }

@danpovey
Can we replace it with CuMatrix?

@fanlu
Author

fanlu commented Mar 6, 2020

The same pattern appears in class LinearComponent:

CuMatrixBase<BaseFloat> &Params() { return params_; }
const CuMatrixBase<BaseFloat> &Params() const { return params_; }
CuMatrix<BaseFloat> params_;

@csukuangfj
Contributor

print(component.LinearParams())
<kaldi_pybind.FloatCuMatrixBase object at 0x7f2890160ef0>

print(type(component.LinearParams()))
<class 'kaldi_pybind.FloatCuMatrixBase'>

What is the type of component? @fanlu

@fanlu
Author

fanlu commented Mar 6, 2020

print(component)
<kaldi_pybind.nnet3.TdnnComponent object at 0x7f1b17fff9f0>

print(type(component))
<class 'kaldi_pybind.nnet3.TdnnComponent'>

@csukuangfj
Contributor

csukuangfj commented Mar 6, 2020

For the following program

class Base {
 public:
  virtual std::string Hello() const = 0;

  virtual ~Base() = default;

 protected:
  Base() = default;
};

class Child : public Base {
 public:
  Child() = default;
  std::string Hello() const override { return "hello"; }
};

class Test {
 public:
  Test() = default;
  const Base& GetChild() const { return c_; }

 private:
  Child c_;
};

The return type of GetChild() in Python is Child.


Please check your code again.

It should return an object of type FloatCuMatrix.

@fanlu fanlu changed the title [WIP] support load kaldi model in python support load kaldi model in python Mar 6, 2020
@csukuangfj
Contributor

+1

@danpovey danpovey merged commit 61cda6b into kaldi-asr:pybind11 Mar 13, 2020
megazone87 pushed a commit to megazone87/kaldi that referenced this pull request Mar 16, 2020
* support load kaldi model in python

* add some component

* split one file to multi component wrap files

* fix some bugs and add test mdl

* add testmode func in batchnorm pybind

* change StatsSum StatsSumsq to Mean Var

* make const