Skip to content

Commit 53985ac

Browse files
authored
Merge pull request #9 from lucasalvaa/deduplication_changes
Update deduplication.py
2 parents a234271 + f9bb4c3 commit 53985ac

File tree

1 file changed

+159
-118
lines changed

1 file changed

+159
-118
lines changed

src/preprocessing/deduplicate.py

Lines changed: 159 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,197 @@
1+
"""Module for dataset deduplication based on MD5 hashing."""
2+
3+
import argparse
4+
import csv
15
import hashlib
2-
import os
6+
import json
7+
import shutil
8+
import sys
39
from collections import defaultdict
10+
from pathlib import Path
411

12+
import yaml
513

6-
def get_image_hash(filepath: Path) -> str:
    """Generate an MD5 hash for a file to identify identical content.

    Args:
        filepath: Path to the file to hash.

    Returns:
        The hexadecimal MD5 digest of the file's bytes.
    """
    # MD5 is used purely as a content fingerprint, not for security, so flag
    # it as such; this keeps the tool usable on FIPS-restricted Python builds.
    hash_md5 = hashlib.md5(usedforsecurity=False)
    with open(filepath, "rb") as f:
        # Read in fixed-size chunks so large images never load fully into memory.
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
2122

2223

23-
def scan_dataset(root_dir: Path) -> dict[str, list[Path]]:
    """Scan subdirectories and map ALL hashes to file paths.

    Args:
        root_dir: Root directory whose immediate subdirectories are the
            class/category folders to scan.

    Returns:
        Mapping from MD5 hash to every file path with that content,
        including hashes that occur only once.
    """
    hashes: defaultdict[str, list[Path]] = defaultdict(list)

    if not root_dir.exists():
        print(f"[-] Error: Input directory '{root_dir}' does not exist.")
        sys.exit(1)

    # Sort directories and files: iterdir()/glob() order is filesystem-
    # dependent, and downstream code keeps paths[0] of each duplicate group,
    # so a stable scan order makes repeated runs reproducible.
    categories = sorted(d for d in root_dir.iterdir() if d.is_dir())
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}

    print(f"[*] Scanning directories in '{root_dir}'...")

    for cat_path in categories:
        for file_path in sorted(cat_path.glob("*")):
            if file_path.is_file() and file_path.suffix.lower() in valid_extensions:
                file_hash = get_image_hash(file_path)
                hashes[file_hash].append(file_path)

    return hashes
5744

5845

59-
def generate_csv_report(
    duplicates: dict[str, list[Path]], report_path: Path
) -> None:
    """Generate a CSV report listing all duplicates found."""
    print(f"[*] Generating detailed CSV report at '{report_path}'...")

    with open(report_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Hash", "Issue_Type", "Class", "File_Path"])

        for file_hash, paths in duplicates.items():
            # A group spanning more than one class folder is a labelling
            # conflict; within a single folder it is plain redundancy.
            distinct_classes = {p.parent.name for p in paths}
            issue_type = (
                "CROSS_CLASS_CONFLICT"
                if len(distinct_classes) > 1
                else "SAME_CLASS_REDUNDANCY"
            )

            for p in paths:
                writer.writerow([file_hash, issue_type, p.parent.name, str(p)])
7868

79-
def create_clean_dataset(
    all_hashes: dict[str, list[Path]], src_root: Path, dest_root: Path
) -> dict:
    """Copy valid files to a new directory.

    Args:
        all_hashes: Mapping of content hash to every path with that content.
        src_root: Root of the original dataset.
        dest_root: Root of the deduplicated copy (recreated from scratch).

    Returns:
        Stats dict with counts of copied/skipped files and per-class removals.
    """
    # Start from a clean slate so stale files from a previous run cannot linger.
    if dest_root.exists():
        shutil.rmtree(dest_root)
    dest_root.mkdir(parents=True)

    stats = {
        "unique_files_copied": 0,
        "redundant_files_skipped": 0,
        "cross_class_conflict_skipped": 0,
        "removed_per_class": defaultdict(int),
    }

    print(f"[*] Creating clean dataset in '{dest_root}'...")

    for paths in all_hashes.values():
        classes_involved = {p.parent.name for p in paths}

        # Case 1: Cross-class duplicates -> SKIP ALL (ambiguous label).
        if len(classes_involved) > 1:
            stats["cross_class_conflict_skipped"] += len(paths)
            for p in paths:
                stats["removed_per_class"][p.parent.name] += 1
            continue

        # Case 2: Same-class duplicates -> keep exactly one representative.
        if len(paths) > 1:
            num_redundant = len(paths) - 1
            stats["redundant_files_skipped"] += num_redundant
            stats["removed_per_class"][paths[0].parent.name] += num_redundant
            paths_to_copy = [paths[0]]
        else:
            # Case 3: Unique -> COPY
            paths_to_copy = paths

        # BUGFIX: count every file that lands in the clean dataset, including
        # the single representative kept from a duplicate group; previously
        # the kept copy was not counted, so the summary's
        # "valid_images_preserved" under-reported the clean dataset size.
        stats["unique_files_copied"] += len(paths_to_copy)

        # Copy, preserving the class-subdirectory layout under dest_root.
        for src_path in paths_to_copy:
            rel_path = src_path.relative_to(src_root)
            dest_path = dest_root / rel_path
            dest_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_path, dest_path)

    return stats
117+
118+
119+
def generate_summary_report(stats: dict, report_path: Path) -> None:
    """Write the final summary statistics to a JSON file."""
    print(f"[*] Saving summary report to '{report_path}'...")

    removed_redundant = stats['redundant_files_skipped']
    removed_conflicts = stats['cross_class_conflict_skipped']

    summary_data = {
        "valid_images_preserved": stats['unique_files_copied'],
        "images_removed_cross_class": removed_conflicts,
        "images_removed_redundancy": removed_redundant,
        "total_removed": removed_conflicts + removed_redundant,
        "removals_per_class": dict(stats["removed_per_class"]),
    }

    with open(report_path, "w") as f:
        json.dump(summary_data, f, indent=4)
138+
139+
def main() -> None:
    """Entry point supporting both Config file and CLI arguments.

    Resolves the input/output directories either from a YAML config file
    (--config) or from explicit --input/--output flags, then runs the
    scan -> report -> copy pipeline and writes reports under 'reports/'.
    """
    parser = argparse.ArgumentParser(description="Dataset Deduplication")

    parser.add_argument(
        "--config", type=str, default=None, help="Path to params.yaml"
    )
    parser.add_argument(
        "--input", type=str, default=None, help="Input directory (raw data)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory (deduplicated data)"
    )

    args = parser.parse_args()

    if args.config:
        print(f"[*] Loading configuration from {args.config}")
        # Fail with a readable message instead of a traceback when the
        # config file is missing, unreadable, or not valid YAML.
        try:
            with open(args.config) as f:
                config = yaml.safe_load(f)
        except OSError as e:
            print(f"[-] Error: Cannot read config file {args.config}: {e}")
            sys.exit(1)
        except yaml.YAMLError as e:
            print(f"[-] Error: Invalid YAML in {args.config}: {e}")
            sys.exit(1)

        try:
            input_dir = Path(config["data"]["raw_path"])
            output_dir = Path(config["data"]["dedup_path"])
        except KeyError as e:
            print(f"[-] Error: Key {e} not found in {args.config}")
            sys.exit(1)

    elif args.input and args.output:
        print("[*] Using command line arguments")
        input_dir = Path(args.input)
        output_dir = Path(args.output)

    else:
        parser.error("You must provide either --config OR both --input and --output.")

    report_dir = Path("reports")
    report_dir.mkdir(exist_ok=True)

    csv_report = report_dir / "duplicates_log.csv"
    summary_report = report_dir / "deduplication_summary.json"

    all_hashes = scan_dataset(input_dir)

    # Only hash groups with more than one path are actual duplicates.
    duplicates_only = {h: p for h, p in all_hashes.items() if len(p) > 1}
    generate_csv_report(duplicates_only, csv_report)

    stats = create_clean_dataset(all_hashes, input_dir, output_dir)
    generate_summary_report(stats, summary_report)

    print("\n[+] Deduplication complete.")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)