diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index aae1f753..6be5c534 100755
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -327,6 +327,41 @@ def test_restricted_samples_float(tmp_path):
     assert data is not None
 
 
+def test_restricted_indices(tmp_path):
+    filename = tmp_path / "dummy_well_data.hdf5"
+    write_dummy_data(filename)
+    # Create dataset without restrictions to get the full length
+    full_dataset = WellDataset(
+        path=str(tmp_path),
+        use_normalization=False,
+        return_grid=True,
+    )
+    full_length = len(full_dataset)
+
+    # Exclude specific indices
+    indices_to_exclude = [0, 1, 5, 10]
+    dataset = WellDataset(
+        path=str(tmp_path),
+        use_normalization=False,
+        return_grid=True,
+        restrict_indices=indices_to_exclude,
+    )
+
+    expected_length = full_length - len(indices_to_exclude)
+    assert (
+        len(dataset) == expected_length
+    ), f"Restricted dataset should contain {expected_length} samples ({full_length} total - {len(indices_to_exclude)} excluded), but found {len(dataset)}"
+
+    # Verify we can still access data
+    data = dataset[0]
+    assert data is not None
+
+    # Verify that the restriction set doesn't contain excluded indices
+    assert all(
+        idx not in indices_to_exclude for idx in dataset.restriction_set
+    ), "Restriction set should not contain any excluded indices"
+
+
 @pytest.mark.parametrize("start_output_steps_at_t", [-1, 4])
 def test_full_trajectory_mode_minimum_steps(tmp_path, start_output_steps_at_t):
     filename = tmp_path / "dummy_well_data.hdf5"
diff --git a/the_well/data/datasets.py b/the_well/data/datasets.py
index ccb6ab62..4b185c57 100755
--- a/the_well/data/datasets.py
+++ b/the_well/data/datasets.py
@@ -152,6 +152,8 @@ class WellDataset(Dataset):
             Whether to restrict the number of trajectories to a subset of the dataset. Integer inputs restrict to a number. Float to a percentage.
         restrict_num_samples:
             Whether to restrict the number of samples to a subset of the dataset. Integer inputs restrict to a number. Float to a percentage.
+        restrict_indices:
+            List of global indices to exclude from the dataset. If combined with other restrictions, exclusion is applied after restrict_num_trajectories and before restrict_num_samples.
         restriction_seed:
             Seed used to generate restriction set. Necessary to ensure same set is sampled across runs.
         cache_small:
@@ -205,6 +207,7 @@ def __init__(
         flatten_tensors: bool = True,
         restrict_num_trajectories: Optional[float | int] = None,
         restrict_num_samples: Optional[float | int] = None,
+        restrict_indices: Optional[list[int]] = None,
         restriction_seed: int = 0,
         cache_small: bool = True,
         max_cache_size: float = 1e9,
@@ -272,6 +275,7 @@ def __init__(
         self.flatten_tensors = flatten_tensors
         self.restrict_num_trajectories = restrict_num_trajectories
         self.restrict_num_samples = restrict_num_samples
+        self.restrict_indices = restrict_indices
         self.restriction_seed = restriction_seed
         self.return_grid = return_grid
         self.normalize_time_grid = normalize_time_grid
@@ -341,23 +345,41 @@ def __init__(
 
         # If we're limiting number of samples/trajectories...
         self.restriction_set = None
-        if restrict_num_samples is not None or restrict_num_trajectories is not None:
+        if (
+            restrict_num_samples is not None
+            or restrict_num_trajectories is not None
+            or restrict_indices is not None
+        ):
             self._build_restriction_set(
-                restrict_num_samples, restrict_num_trajectories, restriction_seed
+                restrict_num_samples,
+                restrict_num_trajectories,
+                restrict_indices,
+                restriction_seed,
             )
+
     def _build_restriction_set(
         self,
         restrict_num_samples: Optional[int | float],
         restrict_num_trajectories: Optional[int | float],
+        restrict_indices: Optional[list[int]],
         seed: int,
     ):
         """Builds a restriction set for the dataset based on the specified restrictions"""
         gen = np.random.default_rng(seed)
 
-        if restrict_num_samples is not None and restrict_num_trajectories is not None:
+        non_none_count = sum(
+            [
+                restrict_num_samples is not None,
+                restrict_num_trajectories is not None,
+                restrict_indices is not None,
+            ]
+        )
+
+        if non_none_count > 1:
             warnings.warn(
-                "Both restrict_num_samples and restrict_num_trajectories are set. Using restrict_num_samples."
+                "Multiple restrictions are set. They are applied sequentially: trajectories, then excluded indices, then samples."
             )
+        global_indices = np.arange(self.len)
         if restrict_num_trajectories is not None:
             # Compute total number of trajectories, collect all indices
             # corresponding to them, then select a subset
@@ -395,6 +417,17 @@ def _build_restriction_set(
                 current_index += self.n_windows_per_trajectory[file_index]
 
             global_indices = np.array(global_indices)
+        if restrict_indices is not None:
+            # Boolean masking keeps integer dtype even when all indices are excluded.
+            skip_set = set(restrict_indices)
+            keep = np.array(
+                [idx not in skip_set for idx in global_indices], dtype=bool
+            )
+            global_indices = global_indices[keep]
+            if len(global_indices) == 0:
+                warnings.warn(
+                    "All indices were excluded by restrict_indices. Dataset will be empty."
+                )
         if restrict_num_samples is not None:
             if 0.0 < restrict_num_samples < 1.0:
                 restrict_num_samples = int(self.len * restrict_num_samples)