Skip to content

Commit e3f96a3

Browse files
committed
Fix types
1 parent 48883f1 commit e3f96a3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/pydvl/reporting/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def plot_ci_array(
104104
means = np.mean(data, axis=0)
105105
variances = np.var(data, axis=0, ddof=1)
106106

107-
dummy: ValuationResult[np.int_, str] = ValuationResult(
107+
dummy = ValuationResult[np.int_, np.object_](
108108
algorithm="dummy",
109109
values=means,
110110
variances=variances,

src/pydvl/value/semivalues.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
import warnings
9393
from enum import Enum
9494
from itertools import islice
95-
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
95+
from typing import Iterable, List, Optional, Protocol, Tuple, Type, cast
9696

9797
import scipy as sp
9898
from deprecate import deprecated
@@ -143,7 +143,7 @@ def __call__(self, n: int, k: int) -> float:
143143

144144

145145
def _marginal(
146-
u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
146+
u: Utility, coefficient: SVCoefficient, samples: Iterable[SampleT]
147147
) -> Tuple[MarginalT, ...]:
148148
"""Computation of marginal utility. This is a helper function for
149149
[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
@@ -278,7 +278,7 @@ def compute_generic_semivalues(
278278
raise StopIteration
279279

280280
# Filter out samples for indices that have already converged
281-
filtered_samples = samples
281+
filtered_samples: Iterable = samples
282282
if skip_converged and len(done.converged) > 0:
283283
# t[0] is the index for the sample
284284
filtered_samples = filter(

0 commit comments

Comments
 (0)