11from __future__ import annotations
22
3+ from typing import TYPE_CHECKING
4+
35import pytest
6+ from numpy .typing import NDArray
47from typing_extensions import Self
58
69from pydvl .utils import Seed , try_torch_import
710from pydvl .valuation .dataset import Dataset
8- from pydvl .valuation .games import DummyGameDataset , MinerGame , ShoesGame
11+ from pydvl .valuation .games import MinerGame , ShoesGame
912from pydvl .valuation .scorers import ClasswiseSupervisedScorer
1013from pydvl .valuation .utility .classwise import ClasswiseModelUtility
1114
12- torch = try_torch_import ()
13-
14- if torch is None :
15- pytest .skip ("PyTorch not available" , allow_module_level = True )
15+ if TYPE_CHECKING :
16+ import torch
17+ else :
18+ if (torch := try_torch_import ()) is None :
19+ pytest .skip ("PyTorch not available" , allow_module_level = True )
1620
1721
1822class TorchLinearClassifier :
@@ -48,8 +52,8 @@ def __init__(self, n_estimators: int, max_samples: float, random_state: Seed):
4852 self .n_estimators = n_estimators
4953 self .max_samples = max_samples
5054 self .random_state = random_state
51- self .estimators_ = []
52- self .estimators_samples_ = []
55+ self .estimators_ : list [ TorchLinearClassifier ] = []
56+ self .estimators_samples_ : list [ NDArray ] = []
5357
5458 def fit (self , X , y ):
5559 n_samples = X .shape [0 ]
@@ -129,21 +133,18 @@ def tensor_classwise_utility(tensor_test_dataset):
129133 )
130134
131135
132- class TensorDummyGameDataset (DummyGameDataset ):
136+ class TensorDummyGameDataset (Dataset [ torch . Tensor ] ):
133137 """Extends DummyGameDataset to use PyTorch tensors instead of NumPy arrays."""
134138
135139 def __init__ (self , n_players : int , description : str = "" ):
136140 x = torch .arange (0 , n_players , 1 ).reshape (- 1 , 1 ).float ()
137141 nil = torch .zeros_like (x )
138- (
139- Dataset .__init__ (
140- self ,
141- x ,
142- nil .clone (),
143- feature_names = ["x" ],
144- target_names = ["y" ],
145- description = description ,
146- ),
142+ super ().__init__ (
143+ x ,
144+ nil .clone (),
145+ feature_names = ["x" ],
146+ target_names = ["y" ],
147+ description = description ,
147148 )
148149
149150
@@ -152,12 +153,12 @@ class TensorMinerGame(MinerGame):
152153
153154 def __init__ (self , n_players : int ):
154155 super ().__init__ (n_players )
155- self .data = TensorDummyGameDataset (self .n_players , "Tensor Miner Game dataset" )
156+ self .data = TensorDummyGameDataset (self .n_players , "Tensor Miner Game dataset" ) # type: ignore[assignment]
156157
157158
158159class TensorShoesGame (ShoesGame ):
159160 """Extends ShoesGame to use PyTorch tensors."""
160161
161162 def __init__ (self , left : int , right : int ):
162163 super ().__init__ (left , right )
163- self .data = TensorDummyGameDataset (self .n_players , "Tensor Shoes Game dataset" )
164+ self .data = TensorDummyGameDataset (self .n_players , "Tensor Shoes Game dataset" ) # type: ignore[assignment]
0 commit comments