Skip to content

Commit 006fa4b

Browse files
authored
fix(skore-hub-project/metrics): Raise exception from thread (#2296)
Since metrics compute is multi-threaded, ensure that any exceptions thrown in a sub-thread are also thrown in the main thread.
1 parent 730a815 commit 006fa4b

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

skore-hub-project/src/skore_hub_project/report/report.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
from abc import ABC
6-
from collections import deque
76
from concurrent.futures import ThreadPoolExecutor, as_completed
87
from functools import cached_property, partial
98
from typing import ClassVar, Generic, TypeVar, cast
@@ -119,13 +118,12 @@ def metrics(self) -> list[Metric[Report]]:
119118
for metric in metrics
120119
]
121120

122-
deque(
123-
progress.track(
124-
as_completed(tasks),
125-
description=f"Computing {self.report.__class__.__name__} metrics",
126-
total=len(tasks),
127-
)
128-
)
121+
for task in progress.track(
122+
as_completed(tasks),
123+
description=f"Computing {self.report.__class__.__name__} metrics",
124+
total=len(tasks),
125+
):
126+
task.result()
129127

130128
return [metric for metric in metrics if metric.value is not None]
131129

skore-hub-project/tests/unit/report/test_cross_validation_report.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,23 @@ def test_metrics(self, payload):
338338
PredictTimeTrainStd,
339339
]
340340

341+
def test_metrics_raises_exception(self, monkeypatch, payload):
342+
"""
343+
Since metrics compute is multi-threaded, ensure that any exceptions thrown in a
344+
sub-thread are also thrown in the main thread.
345+
"""
346+
347+
def raise_exception(_):
348+
raise Exception("test_metrics_raises_exception")
349+
350+
monkeypatch.setattr(
351+
"skore_hub_project.metric.metric.CrossValidationReportMetric.compute",
352+
raise_exception,
353+
)
354+
355+
with raises(Exception, match="test_metrics_raises_exception"):
356+
list(map(type, payload.metrics))
357+
341358
@mark.filterwarnings(
342359
# ignore deprecation warnings generated by the way `pandas` is used by
343360
# `searborn`, which is a dependency of `skore`

skore-hub-project/tests/unit/report/test_estimator_report.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,23 @@ def test_metrics(self, payload):
122122
PredictTimeTrain,
123123
]
124124

125+
def test_metrics_raises_exception(self, monkeypatch, payload):
126+
"""
127+
Since metrics compute is multi-threaded, ensure that any exceptions thrown in a
128+
sub-thread are also thrown in the main thread.
129+
"""
130+
131+
def raise_exception(_):
132+
raise Exception("test_metrics_raises_exception")
133+
134+
monkeypatch.setattr(
135+
"skore_hub_project.metric.metric.EstimatorReportMetric.compute",
136+
raise_exception,
137+
)
138+
139+
with raises(Exception, match="test_metrics_raises_exception"):
140+
list(map(type, payload.metrics))
141+
125142
@mark.usefixtures("monkeypatch_artifact_hub_client")
126143
@mark.usefixtures("monkeypatch_upload_routes")
127144
def test_medias(self, payload):

0 commit comments

Comments
 (0)