Skip to content

Commit 2a395ee

Browse files
committed
Add number of updates to results dataframe
1 parent dca88a7 commit 2a395ee

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

src/pydvl/value/result.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
samples random values uniformly.
4242
4343
"""
44+
4445
from __future__ import annotations
4546

4647
import collections.abc
@@ -357,16 +358,13 @@ def __getattr__(self, attr: str) -> Any:
357358
) from e
358359

359360
@overload
360-
def __getitem__(self, key: int) -> ValueItem:
361-
...
361+
def __getitem__(self, key: int) -> ValueItem: ...
362362

363363
@overload
364-
def __getitem__(self, key: slice) -> List[ValueItem]:
365-
...
364+
def __getitem__(self, key: slice) -> List[ValueItem]: ...
366365

367366
@overload
368-
def __getitem__(self, key: Iterable[int]) -> List[ValueItem]:
369-
...
367+
def __getitem__(self, key: Iterable[int]) -> List[ValueItem]: ...
370368

371369
def __getitem__(
372370
self, key: Union[slice, Iterable[int], int]
@@ -392,16 +390,13 @@ def __getitem__(
392390
raise TypeError("Indices must be integers, iterable or slices")
393391

394392
@overload
395-
def __setitem__(self, key: int, value: ValueItem) -> None:
396-
...
393+
def __setitem__(self, key: int, value: ValueItem) -> None: ...
397394

398395
@overload
399-
def __setitem__(self, key: slice, value: ValueItem) -> None:
400-
...
396+
def __setitem__(self, key: slice, value: ValueItem) -> None: ...
401397

402398
@overload
403-
def __setitem__(self, key: Iterable[int], value: ValueItem) -> None:
404-
...
399+
def __setitem__(self, key: Iterable[int], value: ValueItem) -> None: ...
405400

406401
def __setitem__(
407402
self, key: Union[slice, Iterable[int], int], value: ValueItem
@@ -676,12 +671,15 @@ def to_dataframe(
676671
column = column or self._algorithm
677672
df = pd.DataFrame(
678673
self._values[self._sort_positions],
679-
index=self._names[self._sort_positions]
680-
if use_names
681-
else self._indices[self._sort_positions],
674+
index=(
675+
self._names[self._sort_positions]
676+
if use_names
677+
else self._indices[self._sort_positions]
678+
),
682679
columns=[column],
683680
)
684681
df[column + "_stderr"] = self.stderr[self._sort_positions]
682+
df[column + "_updates"] = self.counts[self._sort_positions]
685683
return df
686684

687685
@classmethod

0 commit comments

Comments
 (0)