Skip to content

Commit 48883f1

Browse files
committed
Filter and avoid conversion
1 parent 128e29f commit 48883f1

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/pydvl/value/semivalues.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,19 +273,17 @@ def compute_generic_semivalues(
273273
# Ensure that we always have n_submitted_jobs running
274274
try:
275275
while len(pending) < n_submitted_jobs:
276-
samples = dict(islice(sampler_it, batch_size))
276+
samples = tuple(islice(sampler_it, batch_size))
277277
if len(samples) == 0:
278278
raise StopIteration
279279

280280
# Filter out samples for indices that have already converged
281+
filtered_samples = samples
281282
if skip_converged and len(done.converged) > 0:
282-
filtered_samples = tuple(
283-
(idx, sample)
284-
for idx, sample in samples.items()
285-
if not done.converged[idx]
283+
# t[0] is the index for the sample
284+
filtered_samples = filter(
285+
lambda t: not done.converged[t[0]], samples
286286
)
287-
else:
288-
filtered_samples = tuple(samples.items())
289287

290288
if filtered_samples:
291289
pending.add(

0 commit comments

Comments
 (0)