|
26 | 26 | import sys |
27 | 27 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) |
28 | 28 |
|
| 29 | +import numpy as np |
| 30 | + |
| 31 | +import kaldi_pybind |
29 | 32 | from kaldi_pybind.nnet3 import _SequentialNnetChainExampleReader |
30 | 33 | from kaldi_pybind.nnet3 import _RandomAccessNnetChainExampleReader |
31 | 34 | from kaldi_pybind.nnet3 import _NnetChainExampleWriter |
|
45 | 48 |
|
46 | 49 | from kaldi_pybind import _CompressedMatrixWriter |
47 | 50 |
|
| 51 | +from kaldi_pybind import _SequentialInt32VectorReader |
| 52 | +from kaldi_pybind import _RandomAccessInt32VectorReader |
| 53 | +from kaldi_pybind import _Int32VectorWriter |
| 54 | + |
48 | 55 | ################################################################################ |
49 | 56 | # Sequential Readers |
50 | 57 | ################################################################################ |
@@ -199,6 +206,12 @@ class SequentialVectorReader(_SequentialReaderBase, |
199 | 206 | pass |
200 | 207 |
|
201 | 208 |
|
| 209 | +class SequentialIntVectorReader(_SequentialReaderBase, |
| 210 | + _SequentialInt32VectorReader): |
| 211 | + '''Sequential table reader for integer sequences.''' |
| 212 | + pass |
| 213 | + |
| 214 | + |
202 | 215 | ################################################################################ |
203 | 216 | # Random Access Readers |
204 | 217 | ################################################################################ |
@@ -334,6 +347,12 @@ class RandomAccessVectorReader(_RandomAccessReaderBase, |
334 | 347 | pass |
335 | 348 |
|
336 | 349 |
|
| 350 | +class RandomAccessIntVectorReader(_RandomAccessReaderBase, |
| 351 | + _RandomAccessInt32VectorReader): |
| 352 | + '''Random access table reader for integer sequences.''' |
| 353 | + pass |
| 354 | + |
| 355 | + |
337 | 356 | ################################################################################ |
338 | 357 | # Writers |
339 | 358 | ################################################################################ |
@@ -431,19 +450,34 @@ class NnetChainExampleWriter(_WriterBase, _NnetChainExampleWriter): |
431 | 450 |
|
432 | 451 | class MatrixWriter(_WriterBase, _BaseFloatMatrixWriter): |
433 | 452 | '''Table writer for single precision matrices.''' |
434 | | - pass |
| 453 | + |
| 454 | + def Write(self, key, value): |
| 455 | + if isinstance(value, np.ndarray): |
| 456 | + m = kaldi_pybind.FloatSubMatrix(value) |
| 457 | + value = kaldi_pybind.FloatMatrix(m) |
| 458 | + super().Write(key, value) |
435 | 459 |
|
436 | 460 |
|
437 | 461 | class VectorWriter(_WriterBase, _BaseFloatVectorWriter): |
438 | 462 | '''Table writer for single precision vectors.''' |
439 | | - pass |
| 463 | + |
| 464 | + def Write(self, key, value): |
| 465 | + if isinstance(value, np.ndarray): |
| 466 | + v = kaldi_pybind.FloatSubVector(value) |
| 467 | + value = kaldi_pybind.FloatVector(v) |
| 468 | + super().Write(key, value) |
440 | 469 |
|
441 | 470 |
|
442 | 471 | class CompressedMatrixWriter(_WriterBase, _CompressedMatrixWriter): |
443 | 472 | '''Table writer for single precision compressed matrices.''' |
444 | 473 | pass |
445 | 474 |
|
446 | 475 |
|
| 476 | +class IntVectorWriter(_WriterBase, _Int32VectorWriter): |
| 477 | + '''Table writer for integer sequences.''' |
| 478 | + pass |
| 479 | + |
| 480 | + |
447 | 481 | if False: |
448 | 482 | # TODO(fangjun): enable the following once other wrappers are added |
449 | 483 |
|
@@ -526,11 +560,6 @@ class SequentialBoolReader(_SequentialReaderBase, |
526 | 560 | '''Sequential table reader for Booleans.''' |
527 | 561 | pass |
528 | 562 |
|
529 | | - class SequentialIntVectorReader(_SequentialReaderBase, |
530 | | - _kaldi_table.SequentialIntVectorReader): |
531 | | - '''Sequential table reader for integer sequences.''' |
532 | | - pass |
533 | | - |
534 | 563 | class SequentialIntVectorVectorReader( |
535 | 564 | _SequentialReaderBase, |
536 | 565 | _kaldi_table.SequentialIntVectorVectorReader): |
@@ -623,11 +652,6 @@ class RandomAccessBoolReader(_RandomAccessReaderBase, |
623 | 652 | '''Random access table reader for Booleans.''' |
624 | 653 | pass |
625 | 654 |
|
626 | | - class RandomAccessIntVectorReader(_RandomAccessReaderBase, |
627 | | - _kaldi_table.RandomAccessIntVectorReader): |
628 | | - '''Random access table reader for integer sequences.''' |
629 | | - pass |
630 | | - |
631 | 655 | class RandomAccessIntVectorVectorReader( |
632 | 656 | _RandomAccessReaderBase, |
633 | 657 | _kaldi_table.RandomAccessIntVectorVectorReader): |
@@ -886,10 +910,6 @@ class BoolWriter(_WriterBase, _kaldi_table.BoolWriter): |
886 | 910 | '''Table writer for Booleans.''' |
887 | 911 | pass |
888 | 912 |
|
889 | | - class IntVectorWriter(_WriterBase, _kaldi_table.IntVectorWriter): |
890 | | - '''Table writer for integer sequences.''' |
891 | | - pass |
892 | | - |
893 | 913 | class IntVectorVectorWriter(_WriterBase, |
894 | 914 | _kaldi_table.IntVectorVectorWriter): |
895 | 915 | '''Table writer for sequences of integer sequences.''' |
|
0 commit comments