Skip to content

Commit 7376b86

Browse files
authored
Merge pull request #23 from ChiSym/shutdown
add async cleanup method to trie
2 parents 84cd7dc + d6a6eab commit 7376b86

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

genlm_backend/trie/async_impl.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,24 @@ async def _background_loop(self):
133133
future.set_exception(e)
134134
raise
135135

136+
async def cleanup(self):
137+
"""Async cleanup - preferred method"""
138+
if self._task and not self._task.done():
139+
self._task.cancel()
140+
try:
141+
await self._task
142+
except asyncio.CancelledError:
143+
pass
144+
self._task = None
145+
136146
def shutdown(self):
137147
"""Stop the background processing task and cleanup resources."""
138-
if self._task:
139-
self._task.cancel()
148+
if self._task is not None:
149+
try:
150+
self._task.cancel()
151+
except RuntimeError:
152+
# Ignore runtime errors that might occur if event loop is closed
153+
pass
140154
self._task = None
141155

142156
def __del__(self):

tests/test_trie.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,16 @@ async def test_async_trie(mock_llm, backend):
174174
np.testing.assert_allclose(have, want, rtol=1e-5, atol=1e-8)
175175

176176

177+
@pytest.mark.asyncio
178+
@pytest.mark.parametrize("backend", ["sequential", "parallel"])
179+
async def test_async_trie_cleanup(mock_llm, backend):
180+
async_trie = AsyncTokenCharacterTrie.from_vocab(
181+
mock_llm.byte_vocab, backend=backend
182+
)
183+
await async_trie.cleanup()
184+
assert async_trie._task is None
185+
186+
177187
def test_sequential_preprocessing(decode):
178188
trie = TokenCharacterTrie(decode=decode)
179189

0 commit comments

Comments
 (0)