5656 "SampleT" ,
5757 "SemivalueCoefficient" ,
5858 "SupervisedModel" ,
59- "TorchSupervisedModel " ,
59+ "SkorchSupervisedModel " ,
6060 "UtilityEvaluation" ,
6161 "ValueUpdate" ,
6262 "ValueUpdateT" ,
@@ -304,7 +304,7 @@ class SupervisedModel(Protocol[ArrayT, ArrayRetT]):
304304 `score()`.
305305 """
306306
307- def fit (self , x : ArrayT , y : ArrayT | None ):
307+ def fit (self , x : ArrayT , y : ArrayT ):
308308 """Fit the model to the data
309309
310310 Args:
@@ -324,7 +324,7 @@ def predict(self, x: ArrayT) -> ArrayRetT:
324324 """
325325 pass
326326
327- def score (self , x : ArrayT , y : ArrayT | None ) -> float :
327+ def score (self , x : ArrayT , y : ArrayT ) -> float :
328328 """Compute the score of the model given test data
329329
330330 Args:
@@ -370,15 +370,15 @@ def predict(self, x: ArrayT) -> ArrayRetT:
370370
371371
372372@runtime_checkable
373- class TorchSupervisedModel (Protocol ):
373+ class SkorchSupervisedModel (Protocol [ ArrayT ] ):
374374 """This is the standard sklearn Protocol with the methods `fit()`, `predict()`
375375 and `score()`, but accepting Tensors and with any additional info required.
376376 It is compatible with [skorch.net.NeuralNet][].
377377 """
378378
379379 device : str | torch_mod .device
380380
381- def fit (self , x : Tensor , y : Tensor | None ):
381+ def fit (self , x : ArrayT , y : Tensor ):
382382 """Fit the model to the data
383383
384384 Args:
@@ -387,7 +387,7 @@ def fit(self, x: Tensor, y: Tensor | None):
387387 """
388388 ...
389389
390- def predict (self , x : Tensor ) -> Tensor :
390+ def predict (self , x : ArrayT ) -> NDArray :
391391 """Compute predictions for the input
392392
393393 Args:
@@ -398,7 +398,7 @@ def predict(self, x: Tensor) -> Tensor:
398398 """
399399 ...
400400
401- def score (self , x : Tensor , y : Tensor | None ) -> float :
401+ def score (self , x : ArrayT , y : NDArray ) -> float :
402402 """Compute the score of the model given test data
403403
404404 Args:
0 commit comments