Skip to content

Commit 1bd7980

Browse files
qindazhucsukuangfj
authored andcommitted
add context manager support for IO (#3846)
* add context manager support for IO add context manager support for IO * update to use single quotes for consistency
1 parent 04f9378 commit 1bd7980

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

src/pybind/kaldi/table.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def __init__(self, rspecifier=''):
9191
def __enter__(self):
9292
return self
9393

94+
def __exit__(self, type, value, traceback):
95+
if self.IsOpen():
96+
self.Close()
97+
9498
def __iter__(self):
9599
while not self.Done():
96100
key = self.Key()
@@ -262,6 +266,10 @@ def __init__(self, rspecifier=''):
262266
def __enter__(self):
263267
return self
264268

269+
def __exit__(self, type, value, traceback):
270+
if self.IsOpen():
271+
self.Close()
272+
265273
def __contains__(self, key):
266274
return self.HasKey(key)
267275

@@ -417,6 +425,10 @@ def __init__(self, wspecifier=''):
417425
def __enter__(self):
418426
return self
419427

428+
def __exit__(self, type, value, traceback):
429+
if self.IsOpen():
430+
self.Close()
431+
420432
def __setitem__(self, key, value):
421433
self.Write(key, value)
422434

src/pybind/tests/test_kaldi_pybind.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def test_matrix_reader_writer(self):
5858
np.testing.assert_array_equal(value.numpy(), gold)
5959

6060
matrix_reader.Close()
61+
62+
# test with context manager
63+
kp_matrix[0, 0] = 20
64+
with kaldi.MatrixWriter(wspecifier) as writer:
65+
writer.Write('id_2', kp_matrix)
66+
with kaldi.SequentialMatrixReader(rspecifier) as reader:
67+
key = reader.Key()
68+
self.assertEqual(key, 'id_2')
69+
value = reader.Value()
70+
gold = np.array([[20, 0, 0], [0, 0, 0]])
71+
np.testing.assert_array_equal(value.numpy(), gold)
72+
6173
os.remove('test.ark')
6274

6375
def test_matrix_reader_iterator(self):

src/pybind/tests/test_table_types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,13 @@ def test_float_matrix(self):
124124
np.testing.assert_array_almost_equal(
125125
reader.Value(key).numpy(), data[key])
126126
reader.Close()
127+
128+
# test RandomAccessReader with context manager
129+
with kaldi.RandomAccessMatrixReader(rspecifier) as reader:
130+
for key in data.keys():
131+
self.assertTrue(reader.HasKey(key))
132+
np.testing.assert_array_almost_equal(
133+
reader.Value(key).numpy(), data[key])
127134

128135
shutil.rmtree(tmp)
129136

0 commit comments

Comments
 (0)