Skip to content

Commit fa7386b

Browse files
fanlu authored and csukuangfj committed
support read xvector egs from python (#3850)
* support read xvector egs from python
* add some document for NnetExample
1 parent 1bd7980 commit fa7386b

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

src/pybind/kaldi/table.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,9 @@
3333
from kaldi_pybind.nnet3 import _RandomAccessNnetChainExampleReader
3434
from kaldi_pybind.nnet3 import _NnetChainExampleWriter
3535

36+
from kaldi_pybind.nnet3 import _SequentialNnetExampleReader
37+
from kaldi_pybind.nnet3 import _RandomAccessNnetExampleReader
38+
3639
from kaldi_pybind.feat import _SequentialWaveReader
3740
from kaldi_pybind.feat import _RandomAccessWaveReader
3841
from kaldi_pybind.feat import _SequentialWaveInfoReader
@@ -194,6 +197,11 @@ class SequentialNnetChainExampleReader(_SequentialReaderBase,
194197
'''Sequential table reader for nnet chain examples.'''
195198
pass
196199

200+
class SequentialNnetExampleReader(_SequentialReaderBase,
201+
_SequentialNnetExampleReader):
202+
'''Sequential table reader for nnet examples.'''
203+
pass
204+
197205

198206
class SequentialWaveReader(_SequentialReaderBase, _SequentialWaveReader):
199207
'''Sequential table reader for wave files.'''
@@ -350,6 +358,10 @@ class RandomAccessNnetChainExampleReader(_RandomAccessReaderBase,
350358
'''Random access table reader for nnet chain examples.'''
351359
pass
352360

361+
class RandomAccessNnetExampleReader(_RandomAccessReaderBase,
362+
_RandomAccessNnetExampleReader):
363+
'''Random access table reader for nnet examples.'''
364+
pass
353365

354366
class RandomAccessWaveReader(_RandomAccessReaderBase, _RandomAccessWaveReader):
355367
'''Random access table reader for wave files.'''
@@ -572,11 +584,6 @@ class SequentialKwsIndexFstReader(
572584
'''Sequential table reader for FSTs over the KWS index semiring.'''
573585
pass
574586

575-
class SequentialNnetExampleReader(_SequentialReaderBase,
576-
_kaldi_table.SequentialNnetExampleReader):
577-
'''Sequential table reader for nnet examples.'''
578-
pass
579-
580587
class SequentialRnnlmExampleReader(_SequentialReaderBase,
581588
_kaldi_table.SequentialRnnlmExampleReader
582589
):
@@ -658,12 +665,6 @@ class RandomAccessKwsIndexFstReader(
658665
'''Random access table reader for FSTs over the KWS index semiring.'''
659666
pass
660667

661-
class RandomAccessNnetExampleReader(
662-
_RandomAccessReaderBase,
663-
_kaldi_table.RandomAccessNnetExampleReader):
664-
'''Random access table reader for nnet examples.'''
665-
pass
666-
667668
class RandomAccessIntReader(_RandomAccessReaderBase,
668669
_kaldi_table.RandomAccessIntReader):
669670
'''Random access table reader for integers.'''

src/pybind/nnet3/nnet_example_pybind.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
// Copyright 2019 Mobvoi AI Lab, Beijing, China
44
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)
5+
// 2020 JD AI, Beijing, China
6+
// (author: Lu Fan)
57

68
// See ../../../COPYING for clarification regarding multiple authors
79
//
@@ -21,6 +23,7 @@
2123
#include "nnet3/nnet_example_pybind.h"
2224

2325
#include "nnet3/nnet-example.h"
26+
#include "util/kaldi_table_pybind.h"
2427

2528
using namespace kaldi;
2629
using namespace kaldi::nnet3;
@@ -40,4 +43,25 @@ void pybind_nnet_example(py::module& m) {
4043
"SparseMatrix would be the natural format for posteriors).");
4144
// TODO(fangjun): other constructors, fields and methods can be wrapped when
4245
}
46+
{
47+
using PyClass = NnetExample;
48+
py::class_<PyClass>(m, "NnetExample",
49+
"NnetExample is the input data and corresponding label (or labels) for one or "
50+
"more frames of input, used for standard cross-entropy training of neural "
51+
"nets (and possibly for other objective functions). ")
52+
.def(py::init<>())
53+
.def_readwrite("io", &PyClass::io,
54+
"\"io\" contains the input and output. In principle there can be multiple "
55+
"types of both input and output, with different names. The order is "
56+
"irrelevant.")
57+
.def("Compress", &PyClass::Compress,
58+
"Compresses any (input) features that are not sparse.")
59+
.def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"));
60+
61+
pybind_sequential_table_reader<KaldiObjectHolder<PyClass>>(
62+
m, "_SequentialNnetExampleReader");
63+
64+
pybind_random_access_table_reader<KaldiObjectHolder<PyClass>>(
65+
m, "_RandomAccessNnetExampleReader");
66+
}
4367
}

0 commit comments

Comments
 (0)