Skip to content

Commit 2bfcd66

Browse files
Fix skip existing not working (#140)
1 parent 122f39b commit 2bfcd66

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

models/rfd3/src/rfd3/engine.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ class RFD3InferenceConfig:
4848
diffusion_batch_size: int = 16
4949

5050
# RFD3 specific
51-
skip_existing: bool = False
52-
json_keys_subset: Optional[List[str]] = None
5351
skip_existing: bool = True
52+
json_keys_subset: Optional[List[str]] = None
5453
specification: Optional[dict] = field(default_factory=dict)
5554
inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
5655

@@ -216,6 +215,9 @@ def run(
216215
inputs=inputs,
217216
n_batches=n_batches,
218217
)
218+
if len(design_specifications) == 0:
219+
ranked_logger.info("No design specifications to run. Skipping.")
220+
return None
219221
ensure_inference_sampler_matches_design_spec(
220222
design_specifications, self.inference_sampler_overrides
221223
)
@@ -381,12 +383,18 @@ def _multiply_specifications(
381383
) -> Dict[str, dict | DesignInputSpecification]:
382384
# Find existing example IDS in output directory
383385
if exists(self.out_dir):
384-
existing_example_ids = set(
386+
existing_example_ids_ = set(
385387
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
386388
for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
387389
)
390+
existing_example_ids = set(
391+
[
392+
"_model_".join(eid.split("_model_")[:-1])
393+
for eid in existing_example_ids_
394+
]
395+
)
388396
ranked_logger.info(
389-
f"Found {len(existing_example_ids)} existing example IDs in the output directory."
397+
f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
390398
)
391399

392400
# Based on inputs, construct the specifications to loop through
@@ -405,7 +413,6 @@ def _multiply_specifications(
405413
for batch_id in range((n_batches) if exists(n_batches) else 1):
406414
# ... Example ID
407415
example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
408-
409416
if (
410417
self.skip_existing
411418
and exists(self.out_dir)

0 commit comments

Comments
 (0)