Skip to content

Commit bad8fe7

Browse files
committed
feat: add parallel execution for remote file operations
Implements ThreadPoolExecutor-based parallel processing for remote file operations with configurable worker pool size (default: 5 workers). Parallel execution is now used for: - File validation (validate_all_config_files) - Resource downloads (process_resources) - File downloads (process_file_downloads) - Hook file downloads (download_hook_files) Provides CLAUDE_SEQUENTIAL_MODE environment variable for backward compatibility, allowing users to disable parallel execution and fall back to sequential processing when needed. Includes comprehensive test suite (29 tests) covering parallel execution, error handling, sequential mode fallback, and integration with existing functionality.
1 parent 2f9a18b commit bad8fe7

File tree

2 files changed

+629
-24
lines changed

2 files changed

+629
-24
lines changed

scripts/setup_environment.py

Lines changed: 201 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# ///
1111

1212
import argparse
13+
import concurrent.futures
1314
import contextlib
1415
import glob as glob_module
1516
import json
@@ -25,9 +26,11 @@
2526
import time
2627
import urllib.error
2728
import urllib.parse
29+
from collections.abc import Callable
2830
from pathlib import Path
2931
from typing import TYPE_CHECKING
3032
from typing import Any
33+
from typing import TypeVar
3134
from typing import cast
3235
from urllib.request import Request
3336
from urllib.request import urlopen
@@ -87,6 +90,131 @@ def debug_log(message: str) -> None:
8790
print(f' [DEBUG] {message}', file=sys.stderr)
8891

8992

93+
# Parallel execution helpers
# Type variables for generic parallel execution: T is the input item type,
# R is the per-item result type produced by the mapped function.
T = TypeVar('T')
R = TypeVar('R')


# Default number of parallel workers (optimal for GitHub API rate limiting)
DEFAULT_PARALLEL_WORKERS = 5
100+
101+
102+
def is_parallel_mode_enabled() -> bool:
    """Check if parallel execution is enabled.

    Parallel mode is the default.  Setting the CLAUDE_SEQUENTIAL_MODE
    environment variable to '1', 'true', or 'yes' (case-insensitive)
    opts back into sequential processing.

    Returns:
        True if parallel mode is enabled (default), False if CLAUDE_SEQUENTIAL_MODE=1
    """
    flag = os.environ.get('CLAUDE_SEQUENTIAL_MODE', '')
    return flag.lower() not in {'1', 'true', 'yes'}
110+
111+
112+
def execute_parallel(
    items: list[T],
    func: Callable[[T], R],
    max_workers: int = DEFAULT_PARALLEL_WORKERS,
) -> list[R]:
    """Execute a function on items in parallel with error isolation.

    Processes items using ThreadPoolExecutor when parallel mode is enabled,
    or sequentially when CLAUDE_SEQUENTIAL_MODE=1.

    Args:
        items: List of items to process
        func: Function to apply to each item
        max_workers: Maximum number of parallel workers (default: 5)

    Returns:
        List of results in the same order as input items.

    Raises:
        Exception: In parallel mode, every item is still processed even if
            some raise; each failure is logged via debug_log and then the
            exception from the lowest-index failing item is re-raised.
            (In sequential mode the first failure propagates immediately.)
    """
    if not items:
        return []

    # Sequential mode fallback
    if not is_parallel_mode_enabled():
        debug_log('Sequential mode enabled, processing items sequentially')
        return [func(item) for item in items]

    # Parallel execution
    debug_log(f'Parallel mode enabled, processing {len(items)} items with {max_workers} workers')

    # Preallocate one slot per item so results land in input order regardless
    # of completion order (removes the collect-then-sort step of the previous
    # implementation).  Each slot ends up holding either R or the exception
    # the item raised.
    slots: list[Any] = [None] * len(items)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks, remembering each item's original position
        future_to_index: dict[concurrent.futures.Future[R], int] = {
            executor.submit(func, item): idx for idx, item in enumerate(items)
        }

        # Collect results as they complete; store exceptions in place so a
        # single failure does not abort the remaining in-flight work.
        for future in concurrent.futures.as_completed(future_to_index):
            idx = future_to_index[future]
            try:
                slots[idx] = future.result()
            except Exception as task_exc:
                slots[idx] = task_exc

    # Log every failure (in input order), then re-raise the first one
    exceptions = [
        (idx, slot) for idx, slot in enumerate(slots) if isinstance(slot, BaseException)
    ]
    if exceptions:
        for exc_idx, stored_exc in exceptions:
            debug_log(f'Item {exc_idx} raised exception: {stored_exc}')
        raise exceptions[0][1]

    return cast('list[R]', slots)
182+
183+
184+
def execute_parallel_safe(
    items: list[T],
    func: Callable[[T], R],
    default_on_error: R,
    max_workers: int = DEFAULT_PARALLEL_WORKERS,
) -> list[R]:
    """Execute a function on items in parallel with error handling.

    Unlike execute_parallel, exceptions raised by individual items are
    caught and replaced with default_on_error, so one failure never aborts
    the batch and partial success is possible.

    Args:
        items: List of items to process
        func: Function to apply to each item
        default_on_error: Value to return for items that raise exceptions
        max_workers: Maximum number of parallel workers (default: 5)

    Returns:
        List of results in the same order as input items.
        Failed items return default_on_error instead of their result.
    """
    if not items:
        return []

    def guarded(item: T) -> R:
        # Log the failure and substitute the fallback value instead of
        # letting the exception escape to execute_parallel.
        try:
            return func(item)
        except Exception as exc:
            debug_log(f'Item processing failed: {exc}')
            return default_on_error

    return execute_parallel(items, guarded, max_workers)
216+
217+
90218
# Windows UAC elevation helper functions
91219
def is_admin() -> bool:
92220
"""Check if running with admin privileges on Windows.
@@ -1361,21 +1489,38 @@ def validate_all_config_files(
13611489
full_path = str(Path(resolved_base) / skill_file_item)
13621490
files_to_check.append(('skill', full_path, full_path, False))
13631491

1364-
# Validate each file
1492+
# Validate each file using parallel execution
13651493
info(f'Validating {len(files_to_check)} files...')
1366-
all_valid = True
13671494

1368-
for file_type, original_path, resolved_path, is_remote in files_to_check:
1369-
# Use FileValidator for unified validation with per-URL authentication
1495+
def validate_single_file(
1496+
file_info: tuple[str, str, str, bool],
1497+
) -> tuple[str, str, bool, str]:
1498+
"""Validate a single file and return result tuple."""
1499+
file_type, original_path, resolved_path, is_remote = file_info
13701500
is_valid, method = validator.validate(resolved_path, is_remote)
1371-
results.append((file_type, original_path, is_valid, method))
1501+
return (file_type, original_path, is_valid, method)
1502+
1503+
# Execute validation in parallel (or sequential if CLAUDE_SEQUENTIAL_MODE=1)
1504+
results = execute_parallel(files_to_check, validate_single_file)
13721505

1506+
# Process results and print status messages
1507+
all_valid = True
1508+
for file_type, original_path, is_valid, method in results:
13731509
if is_valid:
1510+
# Find the resolved_path for this item (for error messages)
1511+
is_remote = method != 'Local'
13741512
if is_remote:
13751513
info(f' [OK] {file_type}: {original_path} (remote, validated via {method})')
13761514
else:
13771515
info(f' [OK] {file_type}: {original_path} (local file exists)')
13781516
else:
1517+
# Find resolved_path for error message
1518+
resolved_path = original_path
1519+
for ft, op, rp, _ir in files_to_check:
1520+
if ft == file_type and op == original_path:
1521+
resolved_path = rp
1522+
break
1523+
is_remote = method != 'Local'
13791524
if is_remote:
13801525
error(f' [FAIL] {file_type}: {original_path} (remote, not accessible)')
13811526
else:
@@ -3191,6 +3336,8 @@ def process_resources(
31913336
) -> bool:
31923337
"""Process resources (download from URL or copy from local) based on configuration.
31933338
3339+
Uses parallel execution when CLAUDE_SEQUENTIAL_MODE is not set.
3340+
31943341
Args:
31953342
resources: List of resource paths from config
31963343
destination_dir: Directory to save resources
@@ -3208,14 +3355,23 @@ def process_resources(
32083355

32093356
info(f'Processing {resource_type}...')
32103357

3358+
# Prepare download tasks
3359+
download_tasks: list[tuple[str, Path]] = []
32113360
for resource in resources:
32123361
# Strip query parameters from URL to get clean filename
32133362
clean_resource = resource.split('?')[0] if '?' in resource else resource
32143363
filename = Path(clean_resource).name
32153364
destination = destination_dir / filename
3216-
handle_resource(resource, destination, config_source, base_url, auth_param)
3365+
download_tasks.append((resource, destination))
32173366

3218-
return True
3367+
def download_single_resource(task: tuple[str, Path]) -> bool:
3368+
"""Download a single resource and return success status."""
3369+
resource, destination = task
3370+
return handle_resource(resource, destination, config_source, base_url, auth_param)
3371+
3372+
# Execute downloads in parallel (or sequential if CLAUDE_SEQUENTIAL_MODE=1)
3373+
results = execute_parallel_safe(download_tasks, download_single_resource, False)
3374+
return all(results)
32193375

32203376

32213377
def process_file_downloads(
@@ -3228,6 +3384,7 @@ def process_file_downloads(
32283384
32293385
Downloads files from URLs or copies from local paths to specified destinations.
32303386
Supports cross-platform path expansion using ~ and environment variables.
3387+
Uses parallel execution when CLAUDE_SEQUENTIAL_MODE is not set.
32313388
32323389
Args:
32333390
file_specs: List of file specifications with 'source' and 'dest' keys.
@@ -3251,8 +3408,10 @@ def process_file_downloads(
32513408
return True
32523409

32533410
info(f'Processing {len(file_specs)} file downloads...')
3254-
success_count = 0
3255-
failed_count = 0
3411+
3412+
# Pre-validate file specs and prepare download tasks
3413+
valid_downloads: list[tuple[str, Path]] = []
3414+
invalid_count = 0
32563415

32573416
for file_spec in file_specs:
32583417
source = file_spec.get('source')
@@ -3266,7 +3425,7 @@ def process_file_downloads(
32663425
warning(f'Invalid file specification: missing dest ({file_spec})')
32673426
else:
32683427
warning(f'Invalid file specification: {file_spec} (missing source or dest)')
3269-
failed_count += 1
3428+
invalid_count += 1
32703429
continue
32713430

32723431
# Expand destination path (~ and environment variables)
@@ -3284,12 +3443,21 @@ def process_file_downloads(
32843443
filename = Path(clean_source).name
32853444
dest_path = dest_path / filename
32863445

3287-
# Use existing handle_resource function for download/copy
3288-
# This handles: URL downloads, local file copying, overwriting, directory creation
3289-
if handle_resource(str(source), dest_path, config_source, base_url, auth_param):
3290-
success_count += 1
3291-
else:
3292-
failed_count += 1
3446+
valid_downloads.append((str(source), dest_path))
3447+
3448+
def download_single_file(download_info: tuple[str, Path]) -> bool:
3449+
"""Download a single file and return success status."""
3450+
source, dest_path = download_info
3451+
return handle_resource(source, dest_path, config_source, base_url, auth_param)
3452+
3453+
# Execute downloads in parallel (or sequential if CLAUDE_SEQUENTIAL_MODE=1)
3454+
if valid_downloads:
3455+
download_results = execute_parallel_safe(valid_downloads, download_single_file, False)
3456+
success_count = sum(1 for result in download_results if result)
3457+
failed_count = len(download_results) - success_count + invalid_count
3458+
else:
3459+
success_count = 0
3460+
failed_count = invalid_count
32933461

32943462
# Print summary
32953463
print() # Blank line for readability
@@ -3850,6 +4018,8 @@ def download_hook_files(
38504018
) -> bool:
38514019
"""Download hook files from configuration.
38524020
4021+
Uses parallel execution when CLAUDE_SEQUENTIAL_MODE is not set.
4022+
38534023
Args:
38544024
hooks: Hooks configuration dictionary with 'files' key
38554025
claude_user_dir: Path to Claude user directory
@@ -3866,23 +4036,30 @@ def download_hook_files(
38664036
info('No hook files to download')
38674037
return True
38684038

4039+
if not config_source:
4040+
error('No config source provided for hook files')
4041+
return False
4042+
38694043
hooks_dir = claude_user_dir / 'hooks'
38704044
hooks_dir.mkdir(parents=True, exist_ok=True)
38714045

4046+
# Prepare download tasks
4047+
download_tasks: list[tuple[str, Path]] = []
38724048
for file in hook_files:
38734049
# Strip query parameters from URL to get clean filename
38744050
clean_file = file.split('?')[0] if '?' in file else file
38754051
filename = Path(clean_file).name
38764052
destination = hooks_dir / filename
3877-
# Handle hook files (download or copy)
3878-
if config_source:
3879-
handle_resource(file, destination, config_source, base_url, auth_param)
3880-
else:
3881-
# This shouldn't happen, but handle gracefully
3882-
error(f'No config source provided for hook file: {file}')
3883-
return False
4053+
download_tasks.append((file, destination))
38844054

3885-
return True
4055+
def download_single_hook(task: tuple[str, Path]) -> bool:
4056+
"""Download a single hook file and return success status."""
4057+
file, destination = task
4058+
return handle_resource(file, destination, config_source, base_url, auth_param)
4059+
4060+
# Execute downloads in parallel (or sequential if CLAUDE_SEQUENTIAL_MODE=1)
4061+
results = execute_parallel_safe(download_tasks, download_single_hook, False)
4062+
return all(results)
38864063

38874064

38884065
def create_additional_settings(

0 commit comments

Comments
 (0)