Skip to content

Commit d692852

Browse files
committed
Add SkorchSupervisedScorer to move tensors to cpu before scoring
1 parent 1a7a18a commit d692852

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

docs/value/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ tensors for data valuation. The implementation follows these key principles:
127127
1. **Type Preservation**: The valuation methods maintain the input data type
128128
throughout computations, whether you provide NumPy arrays or PyTorch tensors
129129
when constructing the [Dataset][pydvl.valuation.dataset.Dataset].
130-
2. **Transparent Usage**: The API remains the same regardless of the input type.
131-
Simply provide your data as tensors. The main difference is that the torch
130+
2. **Transparent Usage**: The API remains the same regardless of the input type -
131+
simply provide your data as tensors. The main difference is that the torch
132132
model must be wrapped in a class compatible with the protocol
133133
[TorchSupervisedModel][pydvl.valuation.types.TorchSupervisedModel].
134134
!!! tip "Wrapping torch models"

notebooks/support/banzhaf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from pydvl.utils import timed
1212
from pydvl.utils.monitor import end_memory_monitoring, start_memory_monitoring
13-
from pydvl.valuation.types import TorchSupervisedModel
1413

1514
from .datasets import load_digits_dataset
1615

@@ -70,7 +69,7 @@ def forward(self, x):
7069
return self.layers(x)
7170

7271

73-
class TorchClassifierModel(TorchSupervisedModel):
72+
class TorchClassifierModel:
7473
"""This class wraps a torch classification model to comply with the
7574
[SupervisedModel][pydvl.utils.types.SupervisedModel] interface expected by pyDVL,
7675
and takes care of the training and evaluation of the model.
@@ -286,7 +285,7 @@ def run(_config):
286285
verbose=False,
287286
)
288287

289-
# scorer = SupervisedScorer(model, test, default=0.0, range=(0.0, 1.0))
288+
# scorer = SkorchSupervisedScorer(model, test, default=0.0, range=(0.0, 1.0))
290289

291290
# utility = ModelUtility(
292291
# model,

src/pydvl/valuation/scorers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from .base import *
1515
from .classwise import *
1616
from .supervised import *
17+
from .torchscorer import *
1718
from .utils import *
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
import torch
6+
7+
from pydvl.valuation.scorers.supervised import SupervisedScorer
8+
from pydvl.valuation.types import SkorchSupervisedModel
9+
10+
__all__ = ["SkorchSupervisedScorer"]
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class SkorchSupervisedScorer(SupervisedScorer[SkorchSupervisedModel, torch.Tensor]):
16+
"""Scorer for Skorch models.
17+
18+
Because skorch models scorer() requires a numpy array to test against, this
19+
class moves tensors to cpu before scoring.
20+
"""
21+
22+
def __call__(self, model: SkorchSupervisedModel) -> float:
23+
x, y = self.test_data.data()
24+
if torch.is_tensor(y):
25+
y = y.cpu().numpy()
26+
return float(self._scorer(model, x, y))

src/pydvl/valuation/types.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"SampleT",
5757
"SemivalueCoefficient",
5858
"SupervisedModel",
59-
"TorchSupervisedModel",
59+
"SkorchSupervisedModel",
6060
"UtilityEvaluation",
6161
"ValueUpdate",
6262
"ValueUpdateT",
@@ -304,7 +304,7 @@ class SupervisedModel(Protocol[ArrayT, ArrayRetT]):
304304
`score()`.
305305
"""
306306

307-
def fit(self, x: ArrayT, y: ArrayT | None):
307+
def fit(self, x: ArrayT, y: ArrayT):
308308
"""Fit the model to the data
309309
310310
Args:
@@ -324,7 +324,7 @@ def predict(self, x: ArrayT) -> ArrayRetT:
324324
"""
325325
pass
326326

327-
def score(self, x: ArrayT, y: ArrayT | None) -> float:
327+
def score(self, x: ArrayT, y: ArrayT) -> float:
328328
"""Compute the score of the model given test data
329329
330330
Args:
@@ -370,15 +370,15 @@ def predict(self, x: ArrayT) -> ArrayRetT:
370370

371371

372372
@runtime_checkable
373-
class TorchSupervisedModel(Protocol):
373+
class SkorchSupervisedModel(Protocol[ArrayT]):
374374
"""This is the standard sklearn Protocol with the methods `fit()`, `predict()`
375375
and `score()`, but accepting Tensors and with any additional info required.
376376
It is compatible with [skorch.net.NeuralNet][].
377377
"""
378378

379379
device: str | torch_mod.device
380380

381-
def fit(self, x: Tensor, y: Tensor | None):
381+
def fit(self, x: ArrayT, y: Tensor):
382382
"""Fit the model to the data
383383
384384
Args:
@@ -387,7 +387,7 @@ def fit(self, x: Tensor, y: Tensor | None):
387387
"""
388388
...
389389

390-
def predict(self, x: Tensor) -> Tensor:
390+
def predict(self, x: ArrayT) -> NDArray:
391391
"""Compute predictions for the input
392392
393393
Args:
@@ -398,7 +398,7 @@ def predict(self, x: Tensor) -> Tensor:
398398
"""
399399
...
400400

401-
def score(self, x: Tensor, y: Tensor | None) -> float:
401+
def score(self, x: ArrayT, y: NDArray) -> float:
402402
"""Compute the score of the model given test data
403403
404404
Args:

0 commit comments

Comments
 (0)