Skip to content

Commit 15c67ec

Browse files
authored
Merge pull request #443 from aai-institute/fix/value-init
Fix data_names in ValuationResult.zeros()
2 parents ac3ed99 + 0d8893c commit 15c67ec

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Fix initialization of `data_names` in `ValuationResult.zeros()`
6+
[PR #443](https://github.com/aai-institute/pyDVL/pull/443)
57
- Using pytest-xdist for faster local tests
68
[PR #440](https://github.com/aai-institute/pyDVL/pull/440)
79
- Added `AntitheticPermutationSampler`

src/pydvl/reporting/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def plot_ci_array(
104104
means = np.mean(data, axis=0)
105105
variances = np.var(data, axis=0, ddof=1)
106106

107-
dummy: ValuationResult[np.int_, str] = ValuationResult(
107+
dummy: ValuationResult[np.int_, np.object_] = ValuationResult(
108108
algorithm="dummy",
109109
values=means,
110110
variances=variances,

src/pydvl/value/result.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,17 @@ def zeros(
784784
indices = np.arange(n_samples, dtype=np.int_)
785785
else:
786786
indices = np.array(indices, dtype=np.int_)
787+
788+
if data_names is None:
789+
data_names = np.array(indices)
790+
else:
791+
data_names = np.array(data_names)
792+
787793
return cls(
788794
algorithm=algorithm,
789795
status=Status.Pending,
790796
indices=indices,
791-
data_names=np.array(data_names, dtype=object)
792-
if data_names is not None
793-
else np.empty_like(indices, dtype=object),
797+
data_names=data_names,
794798
values=np.zeros(len(indices)),
795799
variances=np.zeros(len(indices)),
796800
counts=np.zeros(len(indices), dtype=np.int_),

0 commit comments

Comments
 (0)