|
| 1 | +# Copyright 2025 MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +import logging |
| 13 | +from pathlib import Path |
| 14 | +from typing import List, Union |
| 15 | + |
| 16 | +from monai.deploy.core import ConditionType, Fragment, Operator, OperatorSpec |
| 17 | + |
| 18 | + |
class GenericDirectoryScanner(Operator):
    """Scan a directory for files matching given extensions and emit them one per ``compute`` call.

    The directory is scanned once, in ``setup``; the sorted list of matches is then
    walked one entry per ``compute`` invocation. Once every file has been emitted,
    the next ``compute`` call asks the fragment to stop execution. The operator can
    be chained with file-specific loaders to build flexible data loading pipelines.

    Named Outputs:
        file_path: Path to the current file, as a string
        filename: Name of the current file with the matched extension removed
            (compound extensions such as ``.nii.gz`` are removed as a whole)
        file_index: Current file index (0-based)
        total_files: Total number of files found
    """

    def __init__(
        self,
        fragment: Fragment,
        *args,
        input_folder: Union[str, Path],
        file_extensions: Union[str, List[str]],
        recursive: bool = True,
        case_sensitive: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the GenericDirectoryScanner.

        Args:
            fragment: An instance of the Application class
            input_folder: Path to folder containing files to scan
            file_extensions: Extensions to scan for (e.g., ['.jpg', '.png']). A leading
                dot is added when missing. A single string is treated as one extension.
            recursive: If True, scan subdirectories recursively
            case_sensitive: If True, perform case-sensitive extension matching
        """
        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self._input_folder = Path(input_folder)

        # A bare string would otherwise be iterated character by character,
        # turning "jpg" into the extensions ".j", ".p" and ".g".
        if isinstance(file_extensions, str):
            file_extensions = [file_extensions]
        self._file_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in file_extensions]

        self._recursive = bool(recursive)
        self._case_sensitive = bool(case_sensitive)

        # State tracking: the matched files and the position of the next one to emit.
        self._files: List[Path] = []
        self._current_index = 0

        # NOTE(review): the attributes above are assigned before the base initializer
        # runs, presumably because it may invoke setup(); confirm before reordering.
        super().__init__(fragment, *args, **kwargs)

    def _find_files(self) -> List[Path]:
        """Return the sorted, non-hidden files under the input folder that match an extension."""
        # Normalize extensions once; a tuple lets str.endswith test all of them in one call.
        extensions = tuple(ext if self._case_sensitive else ext.lower() for ext in self._file_extensions)

        # rglob already searches recursively, so a plain "*" pattern is enough for both modes.
        if self._recursive:
            candidates = self._input_folder.rglob("*")
        else:
            candidates = self._input_folder.glob("*")

        matched = []
        for file_path in candidates:
            name = file_path.name

            # Skip hidden files (starting with .) to avoid macOS metadata files like ._file.nii.gz
            if name.startswith("."):
                continue

            if not self._case_sensitive:
                name = name.lower()

            # Matching on the end of the full name handles compound extensions like .nii.gz.
            # The file-system check comes last so it only runs for names that already match.
            if name.endswith(extensions) and file_path.is_file():
                matched.append(file_path)

        # Sort files for consistent ordering
        return sorted(matched)

    def _strip_extension(self, file_name: str) -> str:
        """Return ``file_name`` without the longest configured extension it ends with.

        Falls back to ``Path(file_name).stem`` when no configured extension matches
        or when removing it would leave an empty name.
        """
        # Longest first, so ".nii.gz" wins over ".gz" when both are configured.
        for ext in sorted(self._file_extensions, key=len, reverse=True):
            tail = file_name[-len(ext):]
            head = file_name[: -len(ext)]
            if self._case_sensitive:
                is_match = tail == ext
            else:
                is_match = tail.lower() == ext.lower()
            if is_match and head:
                return head
        return Path(file_name).stem

    def setup(self, spec: OperatorSpec):
        """Define the operator outputs and build the list of files to emit.

        Raises:
            ValueError: If the input folder is not a directory.
        """
        spec.output("file_path")
        spec.output("filename")
        spec.output("file_index").condition(ConditionType.NONE)
        spec.output("total_files").condition(ConditionType.NONE)

        # Pre-initialize the files list
        if not self._input_folder.is_dir():
            raise ValueError(f"Input folder {self._input_folder} is not a directory")

        self._files = self._find_files()
        self._current_index = 0

        if not self._files:
            self._logger.warning(
                f"No files found in {self._input_folder} with extensions {self._file_extensions}"
            )
        else:
            self._logger.info(
                f"Found {len(self._files)} files to process with extensions {self._file_extensions}"
            )

    def compute(self, op_input, op_output, context):
        """Emit the next file's details, or stop the fragment once every file has been emitted."""

        # Nothing left to emit: request that the fragment stops executing.
        if self._current_index >= len(self._files):
            self._logger.info("All files have been processed")
            self.fragment.stop_execution()
            return

        file_path = self._files[self._current_index]

        try:
            op_output.emit(str(file_path), "file_path")
            # Path.stem would leave ".nii" behind for "scan.nii.gz"; strip the matched extension instead.
            op_output.emit(self._strip_extension(file_path.name), "filename")
            op_output.emit(self._current_index, "file_index")
            op_output.emit(len(self._files), "total_files")

            self._logger.info(
                f"Emitted file: {file_path.name} ({self._current_index + 1}/{len(self._files)})"
            )

        except Exception as e:
            # A failure while emitting is logged and this file is skipped,
            # so the remaining files are still attempted on later calls.
            self._logger.error(f"Failed to process file {file_path}: {e}")

        # Move to the next file
        self._current_index += 1
| 156 | + |
| 157 | + |
def test():
    """Smoke-test GenericDirectoryScanner against a throwaway directory tree.

    Builds a temporary folder holding files with assorted extensions (plus one
    subdirectory), scans it for image extensions, and prints what the operator
    emits for up to the first three matches.
    """
    import tempfile

    class MockOutput:
        """Stand-in for the operator output that prints each emission."""

        def emit(self, data, name):
            print(f"Emitted {name}: {data}")

    with tempfile.TemporaryDirectory() as scratch:
        root = Path(scratch)

        # Top-level files covering matching and non-matching extensions.
        for entry in (
            "test1.jpg",
            "test2.png",
            "test3.nii",
            "test4.nii.gz",
            "test5.txt",
            "test6.jpeg",
        ):
            root.joinpath(entry).touch()

        # One nested folder so the recursive scan has something to find.
        nested = root / "subdir"
        nested.mkdir()
        for entry in ("sub_test.jpg", "sub_test.nii"):
            nested.joinpath(entry).touch()

        # Build the operator, looking for image extensions only.
        fragment = Fragment()
        scanner = GenericDirectoryScanner(
            fragment,
            input_folder=root,
            file_extensions=[".jpg", ".jpeg", ".png"],
            recursive=True,
        )

        # Drive setup by hand with a fresh spec.
        from monai.deploy.core import OperatorSpec

        spec = OperatorSpec()
        scanner.setup(spec)

        print(f"Found {len(scanner._files)} image files")

        # Drive compute by hand for at most three files.
        sink = MockOutput()
        limit = min(3, len(scanner._files))
        for i in range(limit):
            print(f"\n--- Processing file {i+1} ---")
            scanner.compute(None, sink, None)
| 208 | + |
| 209 | + |
# Run the manual smoke test when this file is executed directly as a script.
if __name__ == "__main__":
    test()
0 commit comments