Skip to content

Commit 12d48d8

Browse files
csukuangfj authored and naxingyu committed
[pybind] add pybind wrapper for int vector reader/writer. (#3833)
1 parent 958bf0e commit 12d48d8

File tree

10 files changed

+219
-970
lines changed

10 files changed

+219
-970
lines changed

src/pybind/Makefile

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,17 @@ LD_PRELOAD += $(MKL_LD_PRELOAD)
106106
export LD_PRELOAD
107107
endif
108108

109+
# directories should be sorted in alphabetic order
110+
TEST_DIRS := \
111+
chain \
112+
cudamatrix \
113+
dlpack \
114+
feat \
115+
fst \
116+
matrix \
117+
nnet3 \
118+
tests
119+
109120
.PHONY: all clean test
110121

111122
all: $(LIBFILE)
@@ -124,15 +135,7 @@ clean:
124135
-rm -f .depend.mk
125136

126137
test: all
127-
python3 tests/test_dlpack_subvector.py
128-
python3 tests/test_kaldi_pybind.py
129-
make -C chain test
130-
make -C cudamatrix test
131-
make -C dlpack test
132-
make -C feat test
133-
make -C fst test
134-
make -C matrix test
135-
make -C nnet3 test
138+
for d in $(TEST_DIRS); do make -C $$d test; done
136139

137140
# valgrind-python.supp is from http://svn.python.org/projects/python/trunk/Misc/valgrind-python.supp
138141
# since we do not compile Python from source, we follow the comment in valgrind-python.supp
@@ -145,7 +148,7 @@ depend:
145148
rm -f .depend.mk
146149
for f in $(CCFILES); do \
147150
$(CXX) -M -MT "$$(dirname $$f)/$$(basename -s .cc $$f).o" \
148-
$(CXXFLAGS) $$f >> .depend.mk; \
151+
$(CXXFLAGS) $$f >> .depend.mk; \
149152
done
150153

151154
-include .depend.mk

src/pybind/fst/vector_fst_pybind_test.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,9 @@ def test_std_vector_fst(self):
8686
print('read back compiled fst is:')
8787
print(read_back_compiled_fst)
8888

89+
os.remove(compiled_filename)
90+
os.remove(fst_filename)
91+
8992

9093
if __name__ == '__main__':
9194
unittest.main()

src/pybind/kaldi/__init__.py

Lines changed: 2 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -5,28 +5,6 @@
55

66
from kaldi_pybind import *
77

8+
from pytorch_util import *
89
from symbol_table import *
9-
from pytorch_util import PytorchToCuSubMatrix
10-
from pytorch_util import PytorchToCuSubVector
11-
from pytorch_util import PytorchToSubMatrix
12-
from pytorch_util import PytorchToSubVector
13-
14-
from table import SequentialNnetChainExampleReader
15-
from table import RandomAccessNnetChainExampleReader
16-
from table import NnetChainExampleWriter
17-
18-
from table import SequentialWaveReader
19-
from table import RandomAccessWaveReader
20-
21-
from table import SequentialWaveInfoReader
22-
from table import RandomAccessWaveInfoReader
23-
24-
from table import SequentialMatrixReader
25-
from table import RandomAccessMatrixReader
26-
from table import MatrixWriter
27-
28-
from table import SequentialVectorReader
29-
from table import RandomAccessVectorReader
30-
from table import VectorWriter
31-
32-
from table import CompressedMatrixWriter
10+
from table import *

src/pybind/kaldi/table.py

Lines changed: 36 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,9 @@
2626
import sys
2727
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))
2828

29+
import numpy as np
30+
31+
import kaldi_pybind
2932
from kaldi_pybind.nnet3 import _SequentialNnetChainExampleReader
3033
from kaldi_pybind.nnet3 import _RandomAccessNnetChainExampleReader
3134
from kaldi_pybind.nnet3 import _NnetChainExampleWriter
@@ -45,6 +48,10 @@
4548

4649
from kaldi_pybind import _CompressedMatrixWriter
4750

51+
from kaldi_pybind import _SequentialInt32VectorReader
52+
from kaldi_pybind import _RandomAccessInt32VectorReader
53+
from kaldi_pybind import _Int32VectorWriter
54+
4855
################################################################################
4956
# Sequential Readers
5057
################################################################################
@@ -199,6 +206,12 @@ class SequentialVectorReader(_SequentialReaderBase,
199206
pass
200207

201208

209+
class SequentialIntVectorReader(_SequentialReaderBase,
210+
_SequentialInt32VectorReader):
211+
'''Sequential table reader for integer sequences.'''
212+
pass
213+
214+
202215
################################################################################
203216
# Random Access Readers
204217
################################################################################
@@ -334,6 +347,12 @@ class RandomAccessVectorReader(_RandomAccessReaderBase,
334347
pass
335348

336349

350+
class RandomAccessIntVectorReader(_RandomAccessReaderBase,
351+
_RandomAccessInt32VectorReader):
352+
'''Random access table reader for integer sequences.'''
353+
pass
354+
355+
337356
################################################################################
338357
# Writers
339358
################################################################################
@@ -431,19 +450,34 @@ class NnetChainExampleWriter(_WriterBase, _NnetChainExampleWriter):
431450

432451
class MatrixWriter(_WriterBase, _BaseFloatMatrixWriter):
433452
'''Table writer for single precision matrices.'''
434-
pass
453+
454+
def Write(self, key, value):
455+
if isinstance(value, np.ndarray):
456+
m = kaldi_pybind.FloatSubMatrix(value)
457+
value = kaldi_pybind.FloatMatrix(m)
458+
super().Write(key, value)
435459

436460

437461
class VectorWriter(_WriterBase, _BaseFloatVectorWriter):
438462
'''Table writer for single precision vectors.'''
439-
pass
463+
464+
def Write(self, key, value):
465+
if isinstance(value, np.ndarray):
466+
v = kaldi_pybind.FloatSubVector(value)
467+
value = kaldi_pybind.FloatVector(v)
468+
super().Write(key, value)
440469

441470

442471
class CompressedMatrixWriter(_WriterBase, _CompressedMatrixWriter):
443472
'''Table writer for single precision compressed matrices.'''
444473
pass
445474

446475

476+
class IntVectorWriter(_WriterBase, _Int32VectorWriter):
477+
'''Table writer for integer sequences.'''
478+
pass
479+
480+
447481
if False:
448482
# TODO(fangjun): enable the following once other wrappers are added
449483

@@ -526,11 +560,6 @@ class SequentialBoolReader(_SequentialReaderBase,
526560
'''Sequential table reader for Booleans.'''
527561
pass
528562

529-
class SequentialIntVectorReader(_SequentialReaderBase,
530-
_kaldi_table.SequentialIntVectorReader):
531-
'''Sequential table reader for integer sequences.'''
532-
pass
533-
534563
class SequentialIntVectorVectorReader(
535564
_SequentialReaderBase,
536565
_kaldi_table.SequentialIntVectorVectorReader):
@@ -623,11 +652,6 @@ class RandomAccessBoolReader(_RandomAccessReaderBase,
623652
'''Random access table reader for Booleans.'''
624653
pass
625654

626-
class RandomAccessIntVectorReader(_RandomAccessReaderBase,
627-
_kaldi_table.RandomAccessIntVectorReader):
628-
'''Random access table reader for integer sequences.'''
629-
pass
630-
631655
class RandomAccessIntVectorVectorReader(
632656
_RandomAccessReaderBase,
633657
_kaldi_table.RandomAccessIntVectorVectorReader):
@@ -886,10 +910,6 @@ class BoolWriter(_WriterBase, _kaldi_table.BoolWriter):
886910
'''Table writer for Booleans.'''
887911
pass
888912

889-
class IntVectorWriter(_WriterBase, _kaldi_table.IntVectorWriter):
890-
'''Table writer for integer sequences.'''
891-
pass
892-
893913
class IntVectorVectorWriter(_WriterBase,
894914
_kaldi_table.IntVectorVectorWriter):
895915
'''Table writer for sequences of integer sequences.'''

src/pybind/matrix/kaldi_vector_pybind.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,8 @@ void pybind_kaldi_vector(py::module& m) {
6565
})
6666
.def(py::init<const MatrixIndexT, MatrixResizeType>(), py::arg("size"),
6767
py::arg("resize_type") = kSetZero)
68+
.def(py::init<const VectorBase<float>&>(), py::arg("v"),
69+
"Copy-constructor from base-class, needed to copy from SubVector.")
6870
.def("to_dlpack", [](py::object obj) { return VectorToDLPack(&obj); });
6971

7072
py::class_<SubVector<float>, VectorBase<float>>(m, "FloatSubVector")

src/pybind/tests/Makefile

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
2+
test:
3+
python3 ./test_dlpack_subvector.py
4+
python3 ./test_kaldi_pybind.py
5+
python3 ./test_table_types.py

src/pybind/tests/test_kaldi_pybind.py

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -25,30 +25,26 @@ def test_float_vector(self):
2525
self.assertTrue((kp_vector == gold).all())
2626

2727
def test_float_matrix(self):
28-
return
2928
# test FloatMatrix
3029
kp_matrix = kaldi.FloatMatrix(4, 5)
3130

32-
np_matrix = kp_matrix.numpy()
33-
34-
np_matrix[2][3] = 2.0
31+
kp_matrix[2, 3] = 2.0
3532

3633
gold = np.array([
3734
[0, 0, 0, 0, 0],
3835
[0, 0, 0, 0, 0],
3936
[0, 0, 0, 2, 0],
4037
[0, 0, 0, 0, 0],
4138
])
42-
self.assertTrue((kp_matrix == gold).all())
39+
np.testing.assert_array_equal(kp_matrix.numpy(), gold)
4340

4441
def test_matrix_reader_writer(self):
4542
kp_matrix = kaldi.FloatMatrix(2, 3)
4643
wspecifier = 'ark,t:test.ark'
4744
rspecifier = 'ark:test.ark'
4845
matrix_writer = kaldi.MatrixWriter(wspecifier)
4946

50-
np_matrix = kp_matrix.numpy()
51-
np_matrix[0, 0] = 10
47+
kp_matrix[0, 0] = 10
5248

5349
matrix_writer.Write('id_1', kp_matrix)
5450
matrix_writer.Close()
@@ -59,7 +55,10 @@ def test_matrix_reader_writer(self):
5955

6056
value = matrix_reader.Value()
6157
gold = np.array([[10, 0, 0], [0, 0, 0]])
62-
self.assertTrue((np.array(value, copy=False) == gold).all())
58+
np.testing.assert_array_equal(value.numpy(), gold)
59+
60+
matrix_reader.Close()
61+
os.remove('test.ark')
6362

6463
def test_matrix_reader_iterator(self):
6564
kp_matrix = kaldi.FloatMatrix(2, 3)
@@ -71,11 +70,13 @@ def test_matrix_reader_iterator(self):
7170

7271
gold_key_list = ['id_1']
7372
gold_value_list = [np.array([[0, 0, 0], [0, 0, 0]])]
74-
for (key, value), gold_key, gold_value in zip(
75-
kaldi.SequentialMatrixReader(rspecifier), gold_key_list,
76-
gold_value_list):
73+
reader = kaldi.SequentialMatrixReader(rspecifier)
74+
for (key, value), gold_key, gold_value in zip(reader, gold_key_list,
75+
gold_value_list):
7776
self.assertEqual(key, gold_key)
78-
self.assertTrue((value == gold_value).all())
77+
np.testing.assert_array_equal(value.numpy(), gold_value)
78+
reader.Close()
79+
os.remove('test.ark')
7980

8081
def test_matrix_random_access_reader(self):
8182
kp_matrix = kaldi.FloatMatrix(2, 3)
@@ -88,8 +89,11 @@ def test_matrix_random_access_reader(self):
8889
reader = kaldi.RandomAccessMatrixReader(rspecifier)
8990
gold = np.array([[0, 0, 0], [0, 0, 0]])
9091
self.assertTrue('id_1' in reader)
91-
self.assertTrue((np.array(reader['id_1']) == gold).all())
92+
93+
np.testing.assert_array_equal(reader['id_1'].numpy(), gold)
9294
self.assertFalse('id_2' in reader)
95+
reader.Close()
96+
os.remove('test.ark')
9397

9498

9599
if __name__ == '__main__':

0 commit comments

Comments
 (0)