Skip to content

Commit c83a8f4

Browse files
committed
temp fix unittests
1 parent 0e9a179 commit c83a8f4

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ def start(self) -> None:
103103
self.live.auto_refresh = True
104104
self.live._refresh_thread.start()
105105

106+
def stop(self) -> None:
107+
refresh_thread = self.live._refresh_thread
108+
self.live.auto_refresh = refresh_thread is not None
109+
super().stop()
110+
if refresh_thread:
111+
refresh_thread.stop()
112+
refresh_thread.join()
113+
106114
def refresh(self) -> None:
107115
if self.live.auto_refresh:
108116
self.live._refresh_thread.refresh_cond = True

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def test_rich_progress_bar_custom_theme():
131131
_, kwargs = mocks["ProcessingSpeedColumn"].call_args
132132
assert kwargs["style"] == theme.processing_speed
133133

134+
progress_bar.progress.live._refresh_thread.stop()
135+
134136

135137
@RunIf(rich=True)
136138
def test_rich_progress_bar_keyboard_interrupt(tmp_path):
@@ -176,6 +178,8 @@ def configure_columns(self, trainer):
176178
assert progress_bar.progress.columns[0] == custom_column
177179
assert len(progress_bar.progress.columns) == 2
178180

181+
progress_bar.progress.stop()
182+
179183

180184
@RunIf(rich=True)
181185
@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)]))

0 commit comments

Comments
 (0)