|
| 1 | +import os |
| 2 | +import logging |
| 3 | +from multiprocessing import Pool, cpu_count |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +import pandas as pd |
| 7 | +from tqdm import tqdm |
| 8 | +from pathlib import Path |
| 9 | +import plotly.express as px |
| 10 | +import plotly.graph_objects as go |
| 11 | + |
| 12 | +from directory_treemap.utils.file_utils import human_readable_size |
| 13 | + |
| 14 | + |
| 15 | +class DirectoryTreemap: |
| 16 | + """Class to scan a directory and generate a treemap visualization of its contents. |
| 17 | + """ |
| 18 | + |
| 19 | + def __init__(self, base_path: Path, |
| 20 | + output_dir: Path, |
| 21 | + parallel: bool = False): |
| 22 | + """Initialize the DirectoryTreemap instance. |
| 23 | +
|
| 24 | + Args: |
| 25 | + base_path: The base directory to scan. |
| 26 | + output_dir: The directory to save the output HTML file. |
| 27 | + parallel: Whether to use parallel processing for scanning. |
| 28 | + """ |
| 29 | + self.base_path: Path = Path(base_path) |
| 30 | + self.output_dir: Path = Path(output_dir) |
| 31 | + self.max_depth: Optional[int] = None |
| 32 | + self.max_files: Optional[int] = None |
| 33 | + self.parallel: bool = parallel |
| 34 | + self.file_data: list = [] |
| 35 | + self.total_size: float = 0 |
| 36 | + self.df: Optional[pd.DataFrame] = None |
| 37 | + self.fig: Optional[go.Figure] = None |
| 38 | + self.report_filename: Optional[str] = None |
| 39 | + |
| 40 | + self.path_columns: list = [] |
| 41 | + self.limit_depth = None |
| 42 | + self.scan_time = None |
| 43 | + |
| 44 | + # Setup logging |
| 45 | + log_file = self.output_dir / 'dirtreemap.log' |
| 46 | + logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s') |
| 47 | + logging.info( |
| 48 | + f"Initialized DirectoryTreemap with base_path={self.base_path}, output_dir={self.output_dir}, max_depth={self.max_depth}, parallel={self.parallel}") |
| 49 | + |
| 50 | + def _scan_dir(self, path): |
| 51 | + scanned = [] |
| 52 | + try: |
| 53 | + def fast_scan(p): |
| 54 | + for entry in os.scandir(p): |
| 55 | + if entry.is_file(): |
| 56 | + yield entry.path, entry.stat().st_size |
| 57 | + elif entry.is_dir(): |
| 58 | + yield from fast_scan(entry.path) |
| 59 | + |
| 60 | + for file_path, size in fast_scan(path): |
| 61 | + scanned.append((file_path, size)) |
| 62 | + except Exception as e: |
| 63 | + logging.warning(f"Failed to scan directory: {path} - {e}") |
| 64 | + return scanned |
| 65 | + |
| 66 | + def scan(self): |
| 67 | + logging.info("Starting scan...") |
| 68 | + tic = pd.Timestamp.now() |
| 69 | + top_dirs = [entry.path for entry in os.scandir(self.base_path) if entry.is_dir()] |
| 70 | + top_files = [entry.path for entry in os.scandir(self.base_path) if entry.is_file()] |
| 71 | + |
| 72 | + # Scan top-level files first |
| 73 | + for file_path in top_files: |
| 74 | + try: |
| 75 | + size = os.path.getsize(file_path) |
| 76 | + self.file_data.append((file_path, size)) |
| 77 | + self.total_size += size |
| 78 | + except Exception as e: |
| 79 | + logging.warning(f"Failed to access top-level file: {file_path} - {e}") |
| 80 | + |
| 81 | + # Scan subdirectories |
| 82 | + # Parallel branch |
| 83 | + if self.parallel: |
| 84 | + with Pool(cpu_count()) as pool: |
| 85 | + with tqdm(total=len(top_dirs), desc="Scanning directories", dynamic_ncols=True) as pbar: |
| 86 | + for scanned in pool.imap(self._scan_dir, top_dirs): |
| 87 | + self.file_data.extend(scanned) |
| 88 | + self.total_size += sum(size for _, size in scanned) |
| 89 | + pbar.set_postfix({"Scanned": human_readable_size(self.total_size)}) |
| 90 | + pbar.update(1) |
| 91 | + else: |
| 92 | + with tqdm(top_dirs, desc="Scanning directories", dynamic_ncols=True) as pbar: |
| 93 | + for dir_path in pbar: |
| 94 | + scanned = self._scan_dir(dir_path) |
| 95 | + self.file_data.extend(scanned) |
| 96 | + self.total_size += sum(size for _, size in scanned) |
| 97 | + pbar.set_postfix({"Scanned": human_readable_size(self.total_size)}) |
| 98 | + self.scan_time = pd.Timestamp.now() |
| 99 | + logging.info( |
| 100 | + f"Scan completed in {str(pd.Timedelta(self.scan_time - tic)).split('.')[0]}." |
| 101 | + f" Total size scanned: {human_readable_size(self.total_size)}") |
| 102 | + |
| 103 | + def _build_dataframe(self): |
| 104 | + from collections import defaultdict, deque |
| 105 | + |
| 106 | + max_depth = self.max_depth |
| 107 | + max_files = self.max_files if self.max_files is not None else 10 |
| 108 | + |
| 109 | + # Step 1: Build file and directory structure |
| 110 | + file_size_map = {} |
| 111 | + dir_children = defaultdict(list) |
| 112 | + all_dirs = set() |
| 113 | + for file_path, size in self.file_data: |
| 114 | + file_path = Path(file_path) |
| 115 | + file_size_map[str(file_path)] = size |
| 116 | + parent = str(file_path.parent) |
| 117 | + dir_children[parent].append((str(file_path), size)) |
| 118 | + curr = file_path.parent |
| 119 | + while curr != self.base_path.parent: |
| 120 | + all_dirs.add(str(curr)) |
| 121 | + curr = curr.parent |
| 122 | + |
| 123 | + # Step 2: Aggregate files per directory (apply max_files) |
| 124 | + aggregated_file_data = [] |
| 125 | + for dir_path, files in dir_children.items(): |
| 126 | + if max_files and len(files) > max_files: |
| 127 | + files_sorted = sorted(files, key=lambda x: x[1], reverse=True) |
| 128 | + keep = files_sorted[:max_files] |
| 129 | + other = files_sorted[max_files:] |
| 130 | + other_size = sum(size for _, size in other) |
| 131 | + aggregated_file_data.extend(keep) |
| 132 | + if other: |
| 133 | + aggregated_file_data.append(( |
| 134 | + f"{dir_path}/Other files ({len(other)})", |
| 135 | + other_size |
| 136 | + )) |
| 137 | + else: |
| 138 | + aggregated_file_data.extend(files) |
| 139 | + |
| 140 | + # Step 3: Aggregate sizes bottom-up |
| 141 | + dir_sizes = defaultdict(int) |
| 142 | + dir_filecounts = defaultdict(int) |
| 143 | + for file_path, size in aggregated_file_data: |
| 144 | + file_path = Path(file_path) |
| 145 | + rel_parts = file_path.relative_to(self.base_path).parts |
| 146 | + if len(rel_parts) > max_depth: |
| 147 | + ancestor = self.base_path.joinpath(*rel_parts[:max_depth]) |
| 148 | + else: |
| 149 | + ancestor = file_path |
| 150 | + curr = ancestor |
| 151 | + while True: |
| 152 | + dir_sizes[str(curr)] += size |
| 153 | + dir_filecounts[str(curr)] += 1 |
| 154 | + if curr == self.base_path: |
| 155 | + break |
| 156 | + curr = curr.parent |
| 157 | + |
| 158 | + # Step 4: Build all_paths up to max_depth |
| 159 | + all_paths = set() |
| 160 | + for file_path, _ in aggregated_file_data: |
| 161 | + file_path = Path(file_path) |
| 162 | + rel_parts = file_path.relative_to(self.base_path).parts |
| 163 | + for d in range(1, min(len(rel_parts), max_depth) + 1): |
| 164 | + all_paths.add(str(self.base_path.joinpath(*rel_parts[:d]))) |
| 165 | + all_paths.add(str(self.base_path)) |
| 166 | + |
| 167 | + # Step 5: Build DataFrame |
| 168 | + path_to_id = {} |
| 169 | + data = [] |
| 170 | + next_id = 0 |
| 171 | + |
| 172 | + def depth_from_base(p): |
| 173 | + return len(Path(p).relative_to(self.base_path).parts) |
| 174 | + |
| 175 | + for p in sorted(all_paths, key=lambda x: (depth_from_base(x), x)): |
| 176 | + curr = Path(p) |
| 177 | + is_file = str(curr) in file_size_map |
| 178 | + if str(curr) == str(self.base_path): |
| 179 | + parent_id = '' |
| 180 | + label = curr.name if curr.name else str(curr) |
| 181 | + if not label or label == '.': |
| 182 | + label = str(self.base_path) |
| 183 | + else: |
| 184 | + parent_id = str(path_to_id.get(str(curr.parent), '')) |
| 185 | + label = curr.name if curr.name else str(curr) |
| 186 | + if is_file: |
| 187 | + size = file_size_map.get(str(curr), None) |
| 188 | + filecount = 1 |
| 189 | + else: |
| 190 | + size = dir_sizes.get(str(curr), 0) |
| 191 | + filecount = dir_filecounts.get(str(curr), 0) |
| 192 | + data.append({ |
| 193 | + 'id': str(next_id), |
| 194 | + 'parent': parent_id, |
| 195 | + 'label': label, |
| 196 | + 'bytes': size, |
| 197 | + 'size': human_readable_size(size) if size is not None else '', |
| 198 | + 'full_path': str(curr), |
| 199 | + 'filecount': filecount |
| 200 | + }) |
| 201 | + path_to_id[str(curr)] = next_id |
| 202 | + next_id += 1 |
| 203 | + |
| 204 | + self.df = pd.DataFrame(data) |
| 205 | + |
| 206 | + def generate_treemap(self, title: str = 'Directory Treemap', |
| 207 | + max_depth: Optional[int] = None, |
| 208 | + max_files: Optional[int] = 50): |
| 209 | + """Generate the treemap visualization. |
| 210 | + Args: |
| 211 | + title: Title of the treemap. |
| 212 | + max_depth: Maximum directory depth to display. |
| 213 | + max_files: Maximum number of files to display per directory before aggregating into "Other files". |
| 214 | + """ |
| 215 | + |
| 216 | + self.max_depth = max_depth |
| 217 | + self.max_files = max_files |
| 218 | + if self.df is None: |
| 219 | + self._build_dataframe() |
| 220 | + |
| 221 | + df = self.df.copy() |
| 222 | + df.loc[df['parent'] == '', 'label'] = df.loc[df['parent'] == '', 'full_path'].values[0] |
| 223 | + df['bytes'] = df['bytes'].fillna(0) |
| 224 | + |
| 225 | + base_dir_color = "#b3c6ff" # You can pick any color you like |
| 226 | + scan_time = self.scan_time.strftime("%Y-%m-%d %H:%M:%S") |
| 227 | + |
| 228 | + # Build color list: first node is base dir, rest are auto |
| 229 | + colors = [base_dir_color] + [None] * (len(df) - 1) |
| 230 | + |
| 231 | + fig = go.Figure(go.Treemap( |
| 232 | + labels=df['label'], |
| 233 | + parents=df['parent'], |
| 234 | + values=df['bytes'], |
| 235 | + ids=df['id'], |
| 236 | + customdata=df[['size', 'filecount']], |
| 237 | + hovertemplate='<b>%{label}</b><br>Size: %{customdata[0]}<br>Files: %{customdata[1]}<extra></extra>', |
| 238 | + branchvalues="total", |
| 239 | + marker=dict(colors=colors) |
| 240 | + )) |
| 241 | + |
| 242 | + fig.update_layout( |
| 243 | + title=title, |
| 244 | + margin=dict(b=40), # Add space for footer |
| 245 | + annotations=[ |
| 246 | + dict( |
| 247 | + text=f"Scan run: {scan_time}", |
| 248 | + showarrow=False, |
| 249 | + xref="paper", yref="paper", |
| 250 | + x=0.5, y=-0.04, xanchor="center", yanchor="bottom", |
| 251 | + font=dict(size=12, color="gray") |
| 252 | + ) |
| 253 | + ] |
| 254 | + ) |
| 255 | + self.fig = fig |
| 256 | + return fig |
| 257 | + |
| 258 | + def save_report(self, report_filename: str = "directory_treemap.html"): |
| 259 | + """Save the treemap to an HTML file. |
| 260 | + Args: |
| 261 | + report_filename: Name of the output HTML file. |
| 262 | + """ |
| 263 | + |
| 264 | + self.report_filename = report_filename |
| 265 | + |
| 266 | + if not hasattr(self, 'fig'): |
| 267 | + raise ValueError("Treemap not generated. Run generate_treemap() first.") |
| 268 | + output_path = self.output_dir / self.report_filename |
| 269 | + self.fig.write_html(output_path) |
| 270 | + print(f"Treemap saved to {output_path}") |
| 271 | + |
| 272 | + def open_report(self): |
| 273 | + """Open the saved treemap HTML file in the default web browser.""" |
| 274 | + import webbrowser |
| 275 | + output_path = self.output_dir / self.report_filename |
| 276 | + if not output_path.exists(): |
| 277 | + raise ValueError("Report file does not exist. Save the report first.") |
| 278 | + webbrowser.open(f'file://{output_path.resolve()}') |
0 commit comments