diff --git a/.gitignore b/.gitignore index d9a3cfe..8beee89 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ dist/ .pypirc .vscode/ .DS_Store +.codesouler/ diff --git a/build.py b/build.py new file mode 100644 index 0000000..12d7801 --- /dev/null +++ b/build.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" +Build script for handling version information +""" + +import re +from pathlib import Path + + +def get_version(): + """Get version from pyproject.toml""" + pyproject_path = Path("pyproject.toml") + if pyproject_path.exists(): + with open(pyproject_path, "r", encoding="utf-8") as f: + content = f.read() + match = re.search(r'version = "([^"]+)"', content) + if match: + return match.group(1) + return "0.1.0" # default version + + +def update_version_in_init(): + """Update version in __init__.py""" + init_path = Path("pycsghub/__init__.py") + if init_path.exists(): + with open(init_path, "r", encoding="utf-8") as f: + content = f.read() + + # Update version + new_content = re.sub( + r'__version__ = "[^"]*"', + f'__version__ = "{get_version()}"', + content + ) + + with open(init_path, "w", encoding="utf-8") as f: + f.write(new_content) + print(f"Updated version to {get_version()} in __init__.py") + + +if __name__ == "__main__": + update_version_in_init() diff --git a/examples/download_dataset.py b/examples/download_dataset.py index e7a9ccb..d0788a8 100644 --- a/examples/download_dataset.py +++ b/examples/download_dataset.py @@ -1,4 +1,5 @@ from pycsghub.snapshot_download import snapshot_download + # token = "your access token" token = None diff --git a/examples/download_file.py b/examples/download_file.py index c16d690..c3040e5 100644 --- a/examples/download_file.py +++ b/examples/download_file.py @@ -1,4 +1,5 @@ from pycsghub.file_download import file_download + # token = "your access token" token = None @@ -7,11 +8,11 @@ repo_id = 'OpenCSG/csg-wukong-1B' local_dir = "/Users/hhwang/temp/wukong" result = file_download( - repo_id, - file_name='README.md', - local_dir=local_dir, - endpoint=endpoint, - token=token, + repo_id, + file_name='README.md', + local_dir=local_dir, + endpoint=endpoint, + token=token, repo_type=repo_type) print(f"Save file to {result}") diff --git a/examples/download_file_parallel.py b/examples/download_file_parallel.py new file mode 100644 index 0000000..d553b0a --- /dev/null +++ b/examples/download_file_parallel.py @@ -0,0 +1,60 @@ +import logging + +from pycsghub.file_download import file_download, snapshot_download_parallel + +# Configure logging level +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# token = "your access token" +token = None + +endpoint = "https://hub.opencsg.com" +repo_type = "model" +repo_id = 'OpenCSG/csg-wukong-1B' +local_dir = "/Users/hhwang/temp/wukong" + +print("=== Single-file multi-threaded download example ===") +result = file_download( + repo_id, + file_name='README.md', + local_dir=local_dir, + endpoint=endpoint, + token=token, + repo_type=repo_type, + max_workers=4, + use_parallel=True +) + +print(f"Single-file multi-threaded downloaded ,save to: {result}") + +print("\n=== Example of multi-threaded download for the entire repository ===") +cache_dir = "/Users/hhwang/temp/" +allow_patterns = ["*.json", "*.md", "*.txt"] + +result = snapshot_download_parallel( + repo_id, + repo_type=repo_type, + cache_dir=cache_dir, + endpoint=endpoint, + token=token, + allow_patterns=allow_patterns, + max_workers=6, + use_parallel=True, + verbose=True +) + +print(f"Repository downloaded, save to: 
{result}") + +print("\n=== Example of single-threaded download comparison ===") + +result_single = file_download( + repo_id, + file_name='README.md', + local_dir=local_dir, + endpoint=endpoint, + token=token, + repo_type=repo_type, + use_parallel=False +) + +print(f"Single-threaded downloaded, save to: {result_single}") diff --git a/examples/download_model.py b/examples/download_model.py index b7104f5..ddb0fca 100644 --- a/examples/download_model.py +++ b/examples/download_model.py @@ -1,4 +1,5 @@ from pycsghub.snapshot_download import snapshot_download + # token = "your access token" token = None @@ -10,13 +11,12 @@ ignore_patterns = ["tokenizer.json"] result = snapshot_download( - repo_id, - repo_type=repo_type, - local_dir=local_dir, - endpoint=endpoint, + repo_id, + repo_type=repo_type, + local_dir=local_dir, + endpoint=endpoint, token=token, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns) print(f"Save model to {result}") - diff --git a/examples/download_with_custom_progress.py b/examples/download_with_custom_progress.py new file mode 100644 index 0000000..235e2f1 --- /dev/null +++ b/examples/download_with_custom_progress.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +""" +Demonstrate how to use custom progress display for model download +""" + +import time +from datetime import datetime +from pycsghub import snapshot_download + + +class CustomProgressTracker: + """Custom progress tracker""" + + def __init__(self): + self.start_time = None + self.last_update_time = None + + def progress_callback(self, progress_info): + """Custom progress callback function""" + current_time = datetime.now() + + if self.start_time is None: + self.start_time = current_time + self.last_update_time = current_time + + time_since_last = (current_time - self.last_update_time).total_seconds() + if time_since_last >= 1.0 or progress_info['current_downloaded'] == progress_info['total_files']: + self._print_progress(progress_info, current_time) + self.last_update_time = current_time + + def _print_progress(self, progress_info, current_time): + """Print progress information""" + total_files = progress_info['total_files'] + current_downloaded = progress_info['current_downloaded'] + success_count = progress_info['success_count'] + failed_count = progress_info['failed_count'] + remaining_count = progress_info['remaining_count'] + + if total_files > 0: + progress_percent = (current_downloaded / total_files) * 100 + else: + progress_percent = 0 + + elapsed_time = (current_time - self.start_time).total_seconds() + + if current_downloaded > 0: + avg_time_per_file = elapsed_time / current_downloaded + estimated_remaining = avg_time_per_file * remaining_count + else: + estimated_remaining = 0 + + bar_length = 30 + filled_length = int(bar_length * progress_percent / 100) + bar = '█' * filled_length + '░' * (bar_length - filled_length) + + print(f"\r[{bar}] {progress_percent:5.1f}% | " + f"Downloaded: {current_downloaded}/{total_files} | " + f"Success: {success_count} | " + f"Failed: {failed_count} | " + f"Remaining: {remaining_count} | " + f"Elapsed: {elapsed_time:.1f}s | " + f"Estimated remaining: {estimated_remaining:.1f}s", end='', flush=True) + + # If download completed, newline + if current_downloaded == total_files: + print() # Newline + + +def main(): + """ + Main function - Demonstrate custom progress tracking + """ + print("Start demonstrating custom progress tracking for model download...") + + progress_tracker = CustomProgressTracker() + + repo_id = "OpenCSG/csg-wukong-1B" + + try: + local_path = 
snapshot_download( + repo_id=repo_id, + progress_callback=progress_tracker.progress_callback, + verbose=False, + use_parallel=True, + max_workers=4 + ) + + print(f"\n✅ Download completed! Model saved to: {local_path}") + + except Exception as e: + print(f"\n❌ Error during download: {e}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/download_with_progress.py b/examples/download_with_progress.py new file mode 100644 index 0000000..05bb190 --- /dev/null +++ b/examples/download_with_progress.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +""" +Demonstrate how to use the progress callback feature of snapshot_download +""" + +import time +from pycsghub import snapshot_download + + +def progress_callback(progress_info): + """ + Progress callback function + Receives download progress information and prints it + """ + print(f"\n=== Download progress update ===") + print(f"Total files: {progress_info['total_files']}") + print(f"Current downloaded: {progress_info['current_downloaded']}") + print(f"Success count: {progress_info['success_count']}") + print(f"Failed count: {progress_info['failed_count']}") + print(f"Remaining count: {progress_info['remaining_count']}") + + if progress_info['successful_files']: + print(f"Recently successful downloaded file: {progress_info['successful_files'][-1]}") + + if progress_info['remaining_files']: + print(f"Next file to download: {progress_info['remaining_files'][0]}") + + if progress_info['total_files'] > 0: + progress_percent = (progress_info['current_downloaded'] / progress_info['total_files']) * 100 + print(f"Overall progress: {progress_percent:.1f}%") + + print("=" * 30) + + +def main(): + """ + Main function - Demonstrate download with progress callback + """ + print("Start demonstrating download with progress callback...") + + # Example model ID (please replace with actual model ID) + repo_id = "example/model" + + try: + # Use progress callback to download model + local_path = snapshot_download( + repo_id=repo_id, + progress_callback=progress_callback, + verbose=True, + use_parallel=True, + max_workers=4 + ) + + print(f"\nDownload completed! 
Model saved to: {local_path}") + + except Exception as e: + print(f"Error during download: {e}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/run_finetune_bert.py b/examples/run_finetune_bert.py index 7e7b7f5..d44e1d0 100644 --- a/examples/run_finetune_bert.py +++ b/examples/run_finetune_bert.py @@ -1,12 +1,12 @@ from typing import Any -import pandas as pd +import pandas as pd from transformers import DataCollatorWithPadding -from transformers import TrainingArguments from transformers import Trainer +from transformers import TrainingArguments +from pycsghub.repo_reader import AutoModelForSequenceClassification, AutoTokenizer from pycsghub.repo_reader import load_dataset -from pycsghub.repo_reader import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig model_id_or_path = "wanghh2000/bert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) @@ -18,29 +18,34 @@ access_token = None raw_datasets = load_dataset(dsPath, dsName, token=access_token) + def get_data_proprocess() -> Any: - def preprocess_function(examples: pd.DataFrame): + def preprocess_function(examples: pd.DataFrame): ret = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=100) ret = {**examples, **ret} return pd.DataFrame.from_dict(ret) + return preprocess_function + train_dataset = raw_datasets["train"].select(range(20)).map(get_data_proprocess(), batched=True) eval_dataset = raw_datasets["validation"].select(range(20)).map(get_data_proprocess(), batched=True) + def data_collator() -> Any: data_collator = DataCollatorWithPadding(tokenizer=tokenizer) return data_collator + outputDir = "/Users/hhwang/temp/ff" args = TrainingArguments( outputDir, evaluation_strategy="steps", save_strategy="steps", logging_strategy="steps", - logging_steps = 2, - save_steps = 10, - eval_steps = 2, + logging_steps=2, + save_steps=10, + eval_steps=2, learning_rate=2e-5, per_device_train_batch_size=4, per_device_eval_batch_size=4, diff --git a/examples/run_wukong_inference.py b/examples/run_wukong_inference.py index 48cc3b8..c366ef7 100644 --- a/examples/run_wukong_inference.py +++ b/examples/run_wukong_inference.py @@ -1,4 +1,3 @@ -import os from pycsghub.repo_reader import AutoModelForCausalLM, AutoTokenizer mid = 'OpenCSG/csg-wukong-1B' @@ -7,4 +6,4 @@ inputs = tokenizer.encode("Write a short story", return_tensors="pt") outputs = model.generate(inputs) -print('result: ',tokenizer.batch_decode(outputs)) +print('result: ', tokenizer.batch_decode(outputs)) diff --git a/examples/upload_repo.py b/examples/upload_repo.py index e52da73..b9d9f31 100644 --- a/examples/upload_repo.py +++ b/examples/upload_repo.py @@ -11,4 +11,4 @@ repo_type="dataset", ) -r.upload() \ No newline at end of file +r.upload() diff --git a/launch.py b/launch.py index b5c1651..5c4596a 100644 --- a/launch.py +++ b/launch.py @@ -1,6 +1,8 @@ import re import sys + from pycsghub.cli import app + if __name__ == '__main__': sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(app()) \ No newline at end of file + sys.exit(app()) diff --git a/pycsghub/_token.py b/pycsghub/_token.py index 3ec0552..bedde45 100644 --- a/pycsghub/_token.py +++ b/pycsghub/_token.py @@ -1,11 +1,12 @@ -from typing import Optional -from pathlib import Path import os +from pathlib import Path +from typing import Optional + from pycsghub.constants import CSGHUB_TOKEN_PATH def _get_token_from_environment() -> Optional[str]: - return 
_clean_token(os.environ.get("CSGHUB_TOKEN")) # apk key + return _clean_token(os.environ.get("CSGHUB_TOKEN")) # apk key def _get_token_from_file() -> Optional[str]: @@ -22,4 +23,4 @@ def _clean_token(token: Optional[str]) -> Optional[str]: """ if token is None: return None - return token.replace("\r", "").replace("\n", "").strip() or None \ No newline at end of file + return token.replace("\r", "").replace("\n", "").strip() or None diff --git a/pycsghub/cache.py b/pycsghub/cache.py index e87acdf..331e140 100644 --- a/pycsghub/cache.py +++ b/pycsghub/cache.py @@ -2,9 +2,11 @@ import os import pickle import tempfile +import threading from shutil import move, rmtree from typing import Dict, Union + class FileSystemCache(object): KEY_FILE_NAME = '.msc' MODEL_META_FILE_NAME = '.mdl' @@ -14,9 +16,9 @@ class FileSystemCache(object): """ def __init__( - self, - cache_root_location: str, - **kwargs, + self, + cache_root_location: str, + **kwargs, ): """Base file system cache interface. @@ -24,31 +26,71 @@ def __init__( cache_root_location (str): The root location to store files. kwargs(dict): The keyword arguments. """ - os.makedirs(cache_root_location, exist_ok=True) - self.cache_root_location = cache_root_location - self.load_cache() + try: + if os.name == 'nt' and len(os.path.abspath(cache_root_location)) > 240: + print(f"Warning: Cache path too long for Windows: {cache_root_location}") + try: + import win32api + short_path = win32api.GetShortPathName(cache_root_location) + if len(short_path) <= 240: + cache_root_location = short_path + except ImportError: + parts = cache_root_location.split(os.sep) + while len(cache_root_location) > 240 and len(parts) > 1: + parts.pop(1) + cache_root_location = os.sep.join(parts) + + os.makedirs(cache_root_location, exist_ok=True) + self.cache_root_location = cache_root_location + self._lock = threading.RLock() + self.load_cache() + except (OSError, IOError) as e: + raise RuntimeError(f"Failed to initialize cache at {cache_root_location}: {e}") def get_root_location(self): return self.cache_root_location def load_cache(self): + """Load cache metadata with error handling.""" self.cached_files = [] cache_keys_file_path = os.path.join(self.cache_root_location, FileSystemCache.KEY_FILE_NAME) if os.path.exists(cache_keys_file_path): - with open(cache_keys_file_path, 'rb') as f: - self.cached_files = pickle.load(f) + try: + with open(cache_keys_file_path, 'rb') as f: + self.cached_files = pickle.load(f) + except (pickle.PickleError, IOError, EOFError) as e: + print(f"Warning: Cache file corrupted, recreating: {e}") + self.cached_files = [] + try: + os.remove(cache_keys_file_path) + except OSError: + pass def save_cached_files(self): - """Save cache metadata.""" - # save new meta to tmp and move to KEY_FILE_NAME - cache_keys_file_path = os.path.join(self.cache_root_location, - FileSystemCache.KEY_FILE_NAME) - # TODO: Sync file write - fd, fn = tempfile.mkstemp() - with open(fd, 'wb') as f: - pickle.dump(self.cached_files, f) - move(fn, cache_keys_file_path) + """Save cache metadata with atomic operation.""" + with self._lock: + cache_keys_file_path = os.path.join(self.cache_root_location, + FileSystemCache.KEY_FILE_NAME) + try: + if os.name == 'nt': + temp_dir = tempfile.gettempdir() + else: + temp_dir = self.cache_root_location + + fd, fn = tempfile.mkstemp(dir=temp_dir, suffix='.tmp') + try: + with os.fdopen(fd, 'wb') as f: + pickle.dump(self.cached_files, f) + move(fn, cache_keys_file_path) + except (IOError, OSError) as e: + try: + os.remove(fn) + except OSError: + 
pass + raise RuntimeError(f"Failed to save cache metadata: {e}") + except Exception as e: + raise RuntimeError(f"Failed to save cache metadata: {e}") def get_file(self, key): """Check the key is in the cache, if exist, return the file, otherwise return None. @@ -79,15 +121,16 @@ def remove_key(self, key): Args: key (dict): The cache key. """ - if key in self.cached_files: - self.cached_files.remove(key) - self.save_cached_files() + with self._lock: + if key in self.cached_files: + self.cached_files.remove(key) + self.save_cached_files() def exists(self, key): + """Check if key exists in cache with exact match.""" for cache_file in self.cached_files: if cache_file == key: return True - return False def clear_cache(self): @@ -95,8 +138,13 @@ def clear_cache(self): In the case of multiple cache locations, this clears only the last one, which is assumed to be the read/write one. """ - rmtree(self.cache_root_location) - self.load_cache() + with self._lock: + try: + rmtree(self.cache_root_location) + os.makedirs(self.cache_root_location, exist_ok=True) + self.load_cache() + except (OSError, IOError) as e: + raise RuntimeError(f"Failed to clear cache: {e}") def hash_name(self, key): return hashlib.sha256(key.encode()).hexdigest() @@ -111,7 +159,7 @@ class ModelFileSystemCache(FileSystemCache): def __init__(self, cache_root, owner=None, name=None, local_dir: Union[str, None] = None): """Put file to the cache Args: - cache_root(`str`): The csghub local cache root(default: ~/.cache/csghub/) + cache_root(`str`): The csghub local cache root(default: current directory) owner(`str`): The model owner. name('str'): The name of the model Returns: @@ -121,18 +169,27 @@ def __init__(self, cache_root, owner=None, name=None, local_dir: Union[str, None model_id = {owner}/{name} """ - if owner is None or name is None: - # get model meta from - super().__init__(os.path.join(cache_root)) - self.load_model_meta() - else: - super().__init__(os.path.join(cache_root, owner, name)) - self.model_meta = { - FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name) - } - self.save_model_meta() - self.cached_model_revision = self.load_model_version() - self.local_dir = local_dir + try: + if owner is None or name is None: + super().__init__(os.path.join(cache_root)) + self.load_model_meta() + else: + if os.name == 'nt': + invalid_chars = '<>:"|?*' + for char in invalid_chars: + owner = owner.replace(char, '_') + name = name.replace(char, '_') + + cache_path = os.path.join(cache_root, owner, name) + super().__init__(cache_path) + self.model_meta = { + FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name) + } + self.save_model_meta() + self.cached_model_revision = self.load_model_version() + self.local_dir = local_dir + except Exception as e: + raise RuntimeError(f"Failed to initialize ModelFileSystemCache: {e}") def get_root_location(self): if self.local_dir is not None: @@ -141,39 +198,57 @@ def get_root_location(self): return self.cache_root_location def load_model_meta(self): + """Load model metadata with error handling.""" meta_file_path = os.path.join(self.cache_root_location, FileSystemCache.MODEL_META_FILE_NAME) if os.path.exists(meta_file_path): - with open(meta_file_path, 'rb') as f: - self.model_meta = pickle.load(f) + try: + with open(meta_file_path, 'rb') as f: + self.model_meta = pickle.load(f) + except (pickle.PickleError, IOError, EOFError) as e: + print(f"Warning: Model meta file corrupted, using default: {e}") + self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'} else: 
self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'} def load_model_version(self): + """Load model version with error handling.""" model_version_file_path = os.path.join( self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) if os.path.exists(model_version_file_path): - with open(model_version_file_path, 'r') as f: - return f.read().strip() + try: + with open(model_version_file_path, 'r') as f: + return f.read().strip() + except (IOError, UnicodeDecodeError) as e: + print(f"Warning: Model version file corrupted: {e}") + return None else: return None def save_model_version(self, revision_info: Dict): - model_version_file_path = os.path.join( - self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) - with open(model_version_file_path, 'w') as f: - version_info_str = 'Revision:%s' % ( - revision_info['Revision']) - f.write(version_info_str) + """Save model version with error handling.""" + try: + model_version_file_path = os.path.join( + self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) + with open(model_version_file_path, 'w') as f: + version_info_str = 'Revision:%s' % ( + revision_info['Revision']) + f.write(version_info_str) + except (IOError, OSError) as e: + raise RuntimeError(f"Failed to save model version: {e}") def get_model_id(self): return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID] def save_model_meta(self): - meta_file_path = os.path.join(self.cache_root_location, - FileSystemCache.MODEL_META_FILE_NAME) - with open(meta_file_path, 'wb') as f: - pickle.dump(self.model_meta, f) + """Save model metadata with error handling.""" + try: + meta_file_path = os.path.join(self.cache_root_location, + FileSystemCache.MODEL_META_FILE_NAME) + with open(meta_file_path, 'wb') as f: + pickle.dump(self.model_meta, f) + except (IOError, OSError, pickle.PickleError) as e: + raise RuntimeError(f"Failed to save model metadata: {e}") def get_file_by_path(self, file_path): """Retrieve the cache if there is file match the path. @@ -207,7 +282,7 @@ def get_file_by_path_and_commit_id(self, file_path, commit_id): """ for cached_file in self.cached_files: if file_path == cached_file['Path'] and \ - (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): + (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): cached_file_path = os.path.join(self.cache_root_location, cached_file['Path']) if os.path.exists(cached_file_path): @@ -247,7 +322,7 @@ def __get_cache_key(self, model_file_info): return cache_key def exists(self, model_file_info): - """Check the file is cached or not. + """Check the file is cached or not with improved version matching. 
Args: model_file_info (CachedFileInfo): The cached file info @@ -257,20 +332,32 @@ def exists(self, model_file_info): """ key = self.__get_cache_key(model_file_info) is_exists = False + for cached_key in self.cached_files: - if cached_key['Path'] == key['Path'] and ( - cached_key['Revision'].startswith(key['Revision']) - or key['Revision'].startswith(cached_key['Revision'])): - is_exists = True - break + if cached_key['Path'] == key['Path']: + if cached_key['Revision'] == key['Revision']: + is_exists = True + break + elif (len(cached_key['Revision']) >= 6 and + cached_key['Revision'].startswith(key['Revision'])) or \ + (len(key['Revision']) >= 6 and + key['Revision'].startswith(cached_key['Revision'])): + is_exists = True + break + file_path = os.path.join(self.cache_root_location, model_file_info['Path']) if self.local_dir is not None: file_path = os.path.join(self.local_dir, model_file_info['Path']) + if is_exists: if os.path.exists(file_path): return True else: - self.remove_key(model_file_info) # someone may manual delete the file + for cached_file in self.cached_files: + if (cached_file['Path'] == key['Path'] and + cached_file['Revision'] == key['Revision']): + self.remove_key(cached_file) + break return False def remove_if_exists(self, model_file_info): @@ -279,6 +366,7 @@ def remove_if_exists(self, model_file_info): Args: model_file_info (ModelFileInfo): The model file information from server. """ + key = self.__get_cache_key(model_file_info) for cached_file in self.cached_files: if cached_file['Path'] == model_file_info['Path']: self.remove_key(cached_file) @@ -286,7 +374,10 @@ def remove_if_exists(self, model_file_info): if self.local_dir is not None: file_path = os.path.join(self.local_dir, cached_file['Path']) if os.path.exists(file_path): - os.remove(file_path) + try: + os.remove(file_path) + except OSError as e: + print(f"Warning: Failed to remove cached file {file_path}: {e}") break def put_file(self, model_file_info, model_file_location): @@ -299,16 +390,26 @@ def put_file(self, model_file_info, model_file_location): Returns: str: The location of the cached file. 
""" - self.remove_if_exists(model_file_info) - cache_key = self.__get_cache_key(model_file_info) - cache_full_path = os.path.join(self.cache_root_location, cache_key['Path']) - if self.local_dir is not None: - cache_full_path = os.path.join(self.local_dir, cache_key['Path']) - cache_file_dir = os.path.dirname(cache_full_path) - if not os.path.exists(cache_file_dir): - os.makedirs(cache_file_dir, exist_ok=True) - # We can't make operation transaction - move(model_file_location, cache_full_path) - self.cached_files.append(cache_key) - self.save_cached_files() - return cache_full_path \ No newline at end of file + try: + self.remove_if_exists(model_file_info) + cache_key = self.__get_cache_key(model_file_info) + cache_full_path = os.path.join(self.cache_root_location, cache_key['Path']) + if self.local_dir is not None: + cache_full_path = os.path.join(self.local_dir, cache_key['Path']) + + if os.name == 'nt': + cache_full_path = cache_full_path.replace('/', os.sep) + + cache_file_dir = os.path.dirname(cache_full_path) + if not os.path.exists(cache_file_dir): + os.makedirs(cache_file_dir, exist_ok=True) + + if not os.path.exists(model_file_location): + raise RuntimeError(f"Source file does not exist: {model_file_location}") + + move(model_file_location, cache_full_path) + self.cached_files.append(cache_key) + self.save_cached_files() + return cache_full_path + except (OSError, IOError) as e: + raise RuntimeError(f"Failed to put file to cache: {e}") diff --git a/pycsghub/cli.py b/pycsghub/cli.py index 3d81690..9506f0a 100644 --- a/pycsghub/cli.py +++ b/pycsghub/cli.py @@ -1,14 +1,24 @@ -import typer -import os +import getpass import logging +import os +import sys +import traceback +from functools import wraps +from importlib.metadata import version +from pathlib import Path from typing import List, Optional + +import typer from typing_extensions import Annotated + +from pycsghub._token import _get_token_from_file, _clean_token from pycsghub.cmd import repo, inference, finetune from pycsghub.cmd.repo_types import RepoType -from importlib.metadata import version +from pycsghub.constants import CSGHUB_TOKEN_PATH from pycsghub.constants import DEFAULT_CSGHUB_DOMAIN, DEFAULT_REVISION -from .upload_large_folder.main import upload_large_folder_internal from pycsghub.constants import REPO_SOURCE_CSG +from pycsghub.utils import validate_repo_id, get_token_to_send +from .upload_large_folder.main import upload_large_folder_internal logger = logging.getLogger(__name__) @@ -17,65 +27,224 @@ no_args_is_help=True, ) + def version_callback(value: bool): if value: pkg_version = version("csghub-sdk") print(f"csghub-cli version {pkg_version}") raise typer.Exit() + +def auto_inject_token_and_verbose(func): + """Decorator: automatically inject token and verbose parameters""" + @wraps(func) + def wrapper(*args, **kwargs): + if 'token' in kwargs and kwargs['token'] is None: + kwargs['token'] = get_token_to_send() + if kwargs.get('verbose', False): + print(f"[DEBUG] Auto-detected token: {'*' * 10 if kwargs['token'] else 'None'}") + + verbose = kwargs.get('verbose', False) + if verbose: + print(f"[DEBUG] Arguments received:") + for key, value in kwargs.items(): + if key == 'token' and value: + print(f"[DEBUG] {key}: {'*' * 10}") + else: + print(f"[DEBUG] {key}: {value}") + + return func(*args, **kwargs) + return wrapper + + OPTIONS = { "repoID": typer.Argument(help="The ID of the repo. (e.g. `username/repo-name`)."), - "localPath": typer.Argument(help="Local path to the file to upload. 
Defaults to the relative path of the file of repo of OpenCSG Hub."), - "pathInRepo": typer.Argument(help="Path of the folder in the repo. Defaults to the relative path of the file or folder."), + "localPath": typer.Argument( + help="Local path to the file to upload. Defaults to the relative path of the file of repo of OpenCSG Hub."), + "pathInRepo": typer.Argument( + help="Path of the folder in the repo. Defaults to the relative path of the file or folder."), "repoType": typer.Option("-t", "--repo-type", help="Specify the repository type."), "revision": typer.Option("-r", "--revision", help="An optional Git revision id which can be a branch name"), "cache_dir": typer.Option("-cd", "--cache-dir", help="Path to the directory where to save the downloaded files."), - "local_dir": typer.Option("-ld", "--local-dir", help="If provided, the downloaded files will be placed under this directory."), + "local_dir": typer.Option("-ld", "--local-dir", + help="If provided, the downloaded files will be placed under this directory."), "endpoint": typer.Option("-e", "--endpoint", help="The address of the request to be sent."), "username": typer.Option("-u", "--username", help="Logon account of OpenCSG Hub."), - "token": typer.Option("-k", "--token", help="A User Access Token generated from https://opencsg.com/settings/access-token"), + "token": typer.Option("-k", "--token", + help="A User Access Token generated from https://opencsg.com/settings/access-token"), "allow_patterns": typer.Option("--allow-patterns", help="Allow patterns for files to be downloaded."), "ignore_patterns": typer.Option("--ignore-patterns", help="Ignore patterns for files to be downloaded."), - "version": typer.Option(None, "-V", "--version", callback=version_callback, is_eager=True, help="Show the version and exit."), + "version": typer.Option(None, "-V", "--version", callback=version_callback, is_eager=True, + help="Show the version and exit."), "limit": typer.Option("--limit", help="Number of items to list"), "localFolder": typer.Argument(help="Local path to the folder to upload."), - "num_workers": typer.Option("-n","--num-workers", help="Number of concurrent upload workers."), - "print_report": typer.Option("--print-report", help="Whether to print a report of the upload progress. Defaults to True."), - "print_report_every": typer.Option("--print-report-every", help="Frequency at which the report is printed. Defaults to 60 seconds."), + "num_workers": typer.Option("-n", "--num-workers", help="Number of concurrent upload workers."), + "print_report": typer.Option("--print-report", + help="Whether to print a report of the upload progress. Defaults to True."), + "print_report_every": typer.Option("--print-report-every", + help="Frequency at which the report is printed. Defaults to 60 seconds."), "log_level": typer.Option("INFO", "-L", "--log-level", - help="set log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", - case_sensitive=False, - ), + help="set log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", + case_sensitive=False, + ), "source": typer.Option("--source", help="Specify the source of the repository (e.g. 'csg', 'hf', 'ms')."), } + +@app.command(name="login", help="Login to OpenCSG Hub", no_args_is_help=True) +def login( + token: Annotated[Optional[str], OPTIONS["token"]] = None): + """Login to OpenCSG Hub using your access token. 
+ + You can get your access token from https://opencsg.com/settings/access-token + + Examples: + csghub-cli login + csghub-cli login --token your_token_here + """ + + try: + if token is None: + env_token = os.environ.get("CSGHUB_TOKEN") + if env_token: + logger.info("✅ Found token in environment variable CSGHUB_TOKEN") + token = env_token + else: + logger.error("❌ No token found in environment variable CSGHUB_TOKEN") + logger.info("Please use 'csghub-cli login --token your_token_here' to provide your token") + logger.info("You can get your access token from https://opencsg.com/settings/access-token") + raise typer.Exit(1) + + cleaned_token = _clean_token(token) + if not cleaned_token: + logger.error("❌ Error: Invalid token provided.") + raise typer.Exit(1) + + if len(cleaned_token) < 10: + logger.error("❌ Error: Token seems too short. Please check your token.") + raise typer.Exit(1) + + try: + token_path = Path(CSGHUB_TOKEN_PATH) + token_path.parent.mkdir(parents=True, exist_ok=True) + + token_path.write_text(cleaned_token) + + if os.name != 'nt': + os.chmod(token_path, 0o600) + + logger.info("✅ Token saved successfully!") + logger.info(f"Token location: {token_path}") + logger.info("You can now use csghub-cli commands without specifying --token each time.") + + except PermissionError as e: + logger.error(f"❌ Permission error saving token: {e}") + logger.error("Please check if you have write permissions to the token directory.") + raise typer.Exit(1) + except Exception as e: + logger.error(f"❌ Error saving token: {e}") + raise typer.Exit(1) + except Exception as e: + logger.error(f"❌ Error in login command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + + +@app.command(name="logout", help="Logout from OpenCSG Hub", no_args_is_help=True) +def logout(): + """Remove your access token from the local machine. + + This will delete the stored token file. + """ + + try: + token_path = Path(CSGHUB_TOKEN_PATH) + if token_path.exists(): + token_path.unlink() + logger.info("✅ Successfully logged out!") + logger.info("Your access token has been removed from this machine.") + else: + logger.info("ℹ️ No stored token found. You are already logged out.") + + except Exception as e: + logger.error(f"❌ Error in logout command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + + +@app.command(name="whoami", help="Show current user information", no_args_is_help=True) +def whoami( + token: Annotated[Optional[str], OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, +): + """Show information about the currently logged in user. + + This command will display your username and verify that your token is valid. + """ + + try: + if token is None: + token = _get_token_from_file() + + if not token: + logger.error("❌ Not logged in. Please run 'csghub-cli login' first.") + raise typer.Exit(1) + + try: + logger.info("✅ Logged in successfully!") + logger.info(f"Token location: {CSGHUB_TOKEN_PATH}") + logger.info("Note: Token validation requires API access. 
Please test with a download/upload command.") + + except Exception as e: + logger.error(f"❌ Error verifying token: {e}") + raise typer.Exit(1) + except Exception as e: + logger.error(f"❌ Error in whoami command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + + @app.command(name="download", help="Download model/dataset from OpenCSG Hub", no_args_is_help=True) +@auto_inject_token_and_verbose def download( repo_id: Annotated[str, OPTIONS["repoID"]], - repo_type: Annotated[RepoType, OPTIONS["repoType"]] = RepoType.MODEL, + repo_type: Annotated[RepoType, OPTIONS["repoType"]] = RepoType.MODEL, revision: Annotated[Optional[str], OPTIONS["revision"]] = DEFAULT_REVISION, endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, token: Annotated[Optional[str], OPTIONS["token"]] = None, cache_dir: Annotated[Optional[str], OPTIONS["cache_dir"]] = None, local_dir: Annotated[Optional[str], OPTIONS["local_dir"]] = None, - allow_patterns: Annotated[Optional[List[str]], OPTIONS["allow_patterns"]] = None, - ignore_patterns: Annotated[Optional[List[str]], OPTIONS["ignore_patterns"]] = None, + allow_patterns: Annotated[Optional[str], OPTIONS["allow_patterns"]] = None, + ignore_patterns: Annotated[Optional[str], OPTIONS["ignore_patterns"]] = None, source: Annotated[str, OPTIONS["source"]] = REPO_SOURCE_CSG, - ): - repo.download( - repo_id=repo_id, - repo_type=repo_type.value, - revision=revision, - cache_dir=cache_dir, - local_dir=local_dir, - endpoint=endpoint, - token=token, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - source=source, - ) +): + try: + repo.download( + repo_id=repo_id, + repo_type=repo_type.value, + revision=revision, + cache_dir=cache_dir, + local_dir=local_dir, + endpoint=endpoint, + token=token, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + source=source, + enable_parallel=False, + max_parallel_workers=4, + ) + except Exception as e: + logger.error(f"❌ Error in download command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + @app.command(name="upload", help="Upload repository files to OpenCSG Hub", no_args_is_help=True) +@auto_inject_token_and_verbose def upload( repo_id: Annotated[str, OPTIONS["repoID"]], local_path: Annotated[str, OPTIONS["localPath"]], @@ -85,58 +254,82 @@ def upload( endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, token: Annotated[Optional[str], OPTIONS["token"]] = None, user_name: Annotated[Optional[str], OPTIONS["username"]] = "", - ): - # File upload - if os.path.isfile(local_path): - repo.upload_files( - repo_id=repo_id, - repo_type=repo_type.value, - repo_file=local_path, - path_in_repo=path_in_repo, - revision=revision, - endpoint=endpoint, - token=token - ) - # Folder upload - else: - repo.upload_folder( - repo_id=repo_id, - repo_type=repo_type.value, +): + try: + validate_repo_id(repo_id) + + if os.path.isfile(local_path): + repo.upload_files( + repo_id=repo_id, + repo_type=repo_type.value, + repo_file=local_path, + path_in_repo=path_in_repo, + revision=revision, + endpoint=endpoint, + token=token, + ) + # Folder upload + else: + repo.upload_folder( + repo_id=repo_id, + repo_type=repo_type.value, + local_path=local_path, + path_in_repo=path_in_repo, + revision=revision, + endpoint=endpoint, + token=token, + user_name=user_name, + ) + except ValueError as e: + logger.error(f"{e}") + sys.exit(1) + except Exception as e: + logger.error(f"❌ Error in upload command: {e}") + logger.error("📋 
Full stack trace:") + traceback.print_exc() + sys.exit(1) + + +@app.command(name="upload-large-folder", help="Upload large folder to OpenCSG Hub using multiple workers", + no_args_is_help=True) +@auto_inject_token_and_verbose +def upload_large_folder( + repo_id: Annotated[str, OPTIONS["repoID"]], + local_path: Annotated[str, OPTIONS["localFolder"]], + repo_type: Annotated[RepoType, OPTIONS["repoType"]] = RepoType.MODEL, + revision: Annotated[Optional[str], OPTIONS["revision"]] = DEFAULT_REVISION, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + token: Annotated[Optional[str], OPTIONS["token"]] = None, + allow_patterns: Annotated[Optional[List[str]], OPTIONS["allow_patterns"]] = None, + ignore_patterns: Annotated[Optional[List[str]], OPTIONS["ignore_patterns"]] = None, + num_workers: Annotated[int, OPTIONS["num_workers"]] = None, + print_report: Annotated[bool, OPTIONS["print_report"]] = False, + print_report_every: Annotated[int, OPTIONS["print_report_every"]] = 60, +): + try: + validate_repo_id(repo_id) + + upload_large_folder_internal( + repo_id=repo_id, local_path=local_path, - path_in_repo=path_in_repo, + repo_type=repo_type.value, revision=revision, endpoint=endpoint, token=token, - user_name=user_name + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + num_workers=num_workers, + print_report=print_report, + print_report_every=print_report_every, ) - -@app.command(name="upload-large-folder", help="Upload large folder to OpenCSG Hub using multiple workers", no_args_is_help=True) -def upload_large_folder( - repo_id: Annotated[str, OPTIONS["repoID"]], - local_path: Annotated[str, OPTIONS["localFolder"]], - repo_type: Annotated[RepoType, OPTIONS["repoType"]] = RepoType.MODEL, - revision: Annotated[Optional[str], OPTIONS["revision"]] = DEFAULT_REVISION, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, - token: Annotated[Optional[str], OPTIONS["token"]] = None, - allow_patterns: Annotated[Optional[List[str]], OPTIONS["allow_patterns"]] = None, - ignore_patterns: Annotated[Optional[List[str]], OPTIONS["ignore_patterns"]] = None, - num_workers: Annotated[int, OPTIONS["num_workers"]] = None, - print_report: Annotated[bool, OPTIONS["print_report"]] = False, - print_report_every: Annotated[int, OPTIONS["print_report_every"]] = 60, -): - upload_large_folder_internal( - repo_id=repo_id, - local_path=local_path, - repo_type=repo_type.value, - revision=revision, - endpoint=endpoint, - token=token, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - num_workers=num_workers, - print_report=print_report, - print_report_every=print_report_every, - ) + except ValueError as e: + logger.error(f"{e}") + sys.exit(1) + except Exception as e: + logger.error(f"❌ Error in upload-large-folder command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) inference_app = typer.Typer( @@ -145,96 +338,145 @@ def upload_large_folder( ) app.add_typer(inference_app, name="inference") + @inference_app.command(name="list", help="List inference instances", no_args_is_help=True) +@auto_inject_token_and_verbose def list_inferences( - user_name: Annotated[str, OPTIONS["username"]], - token: Annotated[str, OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, - limit: Annotated[Optional[int], OPTIONS["limit"]] = 50, + user_name: Annotated[str, OPTIONS["username"]], + token: Annotated[str, OPTIONS["token"]] = None, + endpoint: 
Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + limit: Annotated[Optional[int], OPTIONS["limit"]] = 50, ): - inference.list( - user_name=user_name, - token=token, - endpoint=endpoint, - limit=limit, - ) + try: + inference.list( + user_name=user_name, + token=token, + endpoint=endpoint, + limit=limit, + ) + except Exception as e: + logger.error(f"❌ Error in inference list command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + @inference_app.command(name="start", help="Start inference instance", no_args_is_help=True) +@auto_inject_token_and_verbose def start_inference( - model: str = typer.Argument(..., help="model to use for inference"), - id: int = typer.Argument(..., help="ID of the inference instance to start"), - token: Annotated[Optional[str], OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + model: str = typer.Argument(..., help="model to use for inference"), + id: int = typer.Argument(..., help="ID of the inference instance to start"), + token: Annotated[Optional[str], OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, ): - inference.start( - id=id, - model=model, - token=token, - endpoint=endpoint, - ) + try: + inference.start( + id=id, + model=model, + token=token, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f"❌ Error in inference start command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) @inference_app.command(name="stop", help="Stop inference instance", no_args_is_help=True) +@auto_inject_token_and_verbose def stop_inference( - model: str = typer.Argument(..., help="model to use for inference"), - id: int = typer.Argument(..., help="ID of the inference instance to start"), - token: Annotated[Optional[str], OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + model: str = typer.Argument(..., help="model to use for inference"), + id: int = typer.Argument(..., help="ID of the inference instance to start"), + token: Annotated[Optional[str], OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, ): - inference.stop( - id=id, - model=model, - token=token, - endpoint=endpoint, - ) - + try: + inference.stop( + id=id, + model=model, + token=token, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f"❌ Error in inference stop command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + + finetune_app = typer.Typer( no_args_is_help=True, help="Manage fine-tuning instances on OpenCSG Hub" ) app.add_typer(finetune_app, name="finetune") + @finetune_app.command(name="list", help="List fine-tuning instances", no_args_is_help=True) +@auto_inject_token_and_verbose def list_finetune( - user_name: Annotated[str, OPTIONS["username"]], - token: Annotated[str, OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, - limit: Annotated[Optional[int], OPTIONS["limit"]] = 50, + user_name: Annotated[str, OPTIONS["username"]], + token: Annotated[str, OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + limit: Annotated[Optional[int], OPTIONS["limit"]] = 50, ): - finetune.list( - user_name=user_name, - token=token, - endpoint=endpoint, - limit=limit, - ) + try: + finetune.list( + user_name=user_name, + 
token=token, + endpoint=endpoint, + limit=limit, + ) + except Exception as e: + logger.error(f"❌ Error in finetune list command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + @finetune_app.command(name="start", help="Start fine-tuning instance", no_args_is_help=True) +@auto_inject_token_and_verbose def start_finetune( - model: str = typer.Argument(..., help="model to use for fine-tuning"), - id: int = typer.Argument(..., help="ID of the fine-tuning instance to start"), - token: Annotated[Optional[str], OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + model: str = typer.Argument(..., help="model to use for fine-tuning"), + id: int = typer.Argument(..., help="ID of the fine-tuning instance to start"), + token: Annotated[Optional[str], OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, ): - finetune.start( - id=id, - model=model, - token=token, - endpoint=endpoint, - ) + try: + finetune.start( + id=id, + model=model, + token=token, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f"❌ Error in finetune start command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + @finetune_app.command(name="stop", help="Stop fine-tuning instance", no_args_is_help=True) +@auto_inject_token_and_verbose def stop_finetune( - model: str = typer.Argument(..., help="model to use for fine-tuning"), - id: int = typer.Argument(..., help="ID of the fine-tuning instance to stop"), - token: Annotated[Optional[str], OPTIONS["token"]] = None, - endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, + model: str = typer.Argument(..., help="model to use for fine-tuning"), + id: int = typer.Argument(..., help="ID of the fine-tuning instance to stop"), + token: Annotated[Optional[str], OPTIONS["token"]] = None, + endpoint: Annotated[Optional[str], OPTIONS["endpoint"]] = DEFAULT_CSGHUB_DOMAIN, ): - finetune.stop( - id=id, - model=model, - token=token, - endpoint=endpoint, - ) + try: + finetune.stop( + id=id, + model=model, + token=token, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f"❌ Error in finetune stop command: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) + @app.callback( invoke_without_command=True, @@ -245,8 +487,7 @@ def stop_finetune( } ) def main( - version: bool = OPTIONS["version"], - log_level: str = OPTIONS["log_level"] + log_level: str = OPTIONS["log_level"] ): # for example: format='%(asctime)s - %(name)s:%(funcName)s:%(lineno)d - %(levelname)s - %(message)s', logging.basicConfig( @@ -259,4 +500,11 @@ def main( if __name__ == "__main__": - app() + + try: + app() + except Exception as e: + logger.error(f"❌ Error: {e}") + logger.error("📋 Full stack trace:") + traceback.print_exc() + sys.exit(1) diff --git a/pycsghub/cmd/finetune.py b/pycsghub/cmd/finetune.py index ffb855e..d0c9063 100644 --- a/pycsghub/cmd/finetune.py +++ b/pycsghub/cmd/finetune.py @@ -1,12 +1,14 @@ import requests + from pycsghub.utils import (build_csg_headers, get_endpoint) + def list( - user_name: str, - token: str, - endpoint: str, - limit: int, + user_name: str, + token: str, + endpoint: str, + limit: int, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/user/{user_name}/finetune/instances" @@ -24,15 +26,16 @@ def list( if instances: for instance in instances: print(f"{instance['deploy_id']:<10}" - 
f"{instance['deploy_name']:<40}" - f"{instance['model_id']:<50}" - f"{instance['status']:<10}") + f"{instance['deploy_name']:<40}" + f"{instance['model_id']:<50}" + f"{instance['status']:<10}") + def start( - id: int, - model: str, - token: str, - endpoint: str, + id: int, + model: str, + token: str, + endpoint: str, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/models/{model}/finetune/{id}/start" @@ -43,11 +46,12 @@ def start( result = response.json() print(result) + def stop( - id: int, - model: str, - token: str, - endpoint: str, + id: int, + model: str, + token: str, + endpoint: str, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/models/{model}/finetune/{id}/stop" diff --git a/pycsghub/cmd/inference.py b/pycsghub/cmd/inference.py index f396904..934b030 100644 --- a/pycsghub/cmd/inference.py +++ b/pycsghub/cmd/inference.py @@ -1,13 +1,14 @@ import requests + from pycsghub.utils import (build_csg_headers, get_endpoint) def list( - user_name: str, - token: str, - endpoint: str, - limit: int, + user_name: str, + token: str, + endpoint: str, + limit: int, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/user/{user_name}/run/model" @@ -26,15 +27,16 @@ def list( if instances: for instance in instances: print(f"{instance['deploy_id']:<10}" - f"{instance['deploy_name']:<40}" - f"{instance['model_id']:<50}" - f"{instance['status']:<10}") + f"{instance['deploy_name']:<40}" + f"{instance['model_id']:<50}" + f"{instance['status']:<10}") + def detail( - id: int, - model: str, - token: str, - endpoint: str, + id: int, + model: str, + token: str, + endpoint: str, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/models/{model}/run/{id}" @@ -45,11 +47,12 @@ def detail( response.raise_for_status() return response.status_code + def start( - id: int, - model: str, - token: str, - endpoint: str, + id: int, + model: str, + token: str, + endpoint: str, ): detail(id=id, model=model, token=token, endpoint=endpoint) action_endpoint = get_endpoint(endpoint=endpoint) @@ -61,11 +64,12 @@ def start( result = response.json() print(result) + def stop( - id: int, - model: str, - token: str, - endpoint: str, + id: int, + model: str, + token: str, + endpoint: str, ): action_endpoint = get_endpoint(endpoint=endpoint) url = f"{action_endpoint}/api/v1/models/{model}/run/{id}/stop" diff --git a/pycsghub/cmd/repo.py b/pycsghub/cmd/repo.py index 35337fc..f3daaa1 100644 --- a/pycsghub/cmd/repo.py +++ b/pycsghub/cmd/repo.py @@ -1,11 +1,13 @@ -from pycsghub.snapshot_download import snapshot_download -from pycsghub.file_upload import http_upload_file from pathlib import Path -from typing import Optional, Union, List +from typing import Optional, Union + from pycsghub.constants import DEFAULT_REVISION +from pycsghub.file_upload import http_upload_file from pycsghub.repository import Repository +from pycsghub.snapshot_download import snapshot_download from pycsghub.utils import get_token_to_send + def download( repo_id: str, repo_type: str, @@ -14,23 +16,28 @@ def download( local_dir: Union[str, Path, None] = None, endpoint: Optional[str] = None, token: Optional[str] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[str] = None, + ignore_patterns: Optional[str] = None, source: str = None, - ): + enable_parallel: bool = False, + max_parallel_workers: int = 4, +): snapshot_download( 
repo_id=repo_id, repo_type=repo_type, revision=revision, cache_dir=cache_dir, local_dir=local_dir, - endpoint=endpoint, + endpoint=endpoint, token=token, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, source=source, + enable_parallel=enable_parallel, + max_parallel_workers=max_parallel_workers, ) + def upload_files( repo_id: str, repo_type: str, @@ -38,8 +45,17 @@ def upload_files( path_in_repo: Optional[str] = "", revision: Optional[str] = DEFAULT_REVISION, endpoint: Optional[str] = None, - token: Optional[str] = None - ): + token: Optional[str] = None, + verbose: bool = False, +): + if verbose: + print(f"[DEBUG] Uploading file: {repo_file}") + print(f"[DEBUG] To repo: {repo_id}") + print(f"[DEBUG] Repo type: {repo_type}") + print(f"[DEBUG] Path in repo: {path_in_repo}") + print(f"[DEBUG] Revision: {revision}") + print(f"[DEBUG] Endpoint: {endpoint}") + http_upload_file( repo_id=repo_id, repo_type=repo_type, @@ -48,8 +64,10 @@ def upload_files( revision=revision, endpoint=endpoint, token=token, + verbose=verbose, ) + def upload_folder( repo_id: str, repo_type: str, @@ -64,7 +82,19 @@ def upload_folder( user_name: Optional[str] = "", token: Optional[str] = None, auto_create: Optional[bool] = True, - ): + verbose: bool = False, +): + if verbose: + print(f"[DEBUG] Uploading folder: {local_path}") + print(f"[DEBUG] To repo: {repo_id}") + print(f"[DEBUG] Repo type: {repo_type}") + print(f"[DEBUG] Path in repo: {path_in_repo}") + print(f"[DEBUG] Work dir: {work_dir}") + print(f"[DEBUG] Revision: {revision}") + print(f"[DEBUG] Endpoint: {endpoint}") + print(f"[DEBUG] User name: {user_name}") + print(f"[DEBUG] Auto create: {auto_create}") + r = Repository( repo_id=repo_id, upload_path=local_path, @@ -79,6 +109,6 @@ def upload_folder( user_name=user_name, token=get_token_to_send(token), auto_create=auto_create, + verbose=verbose, ) r.upload() - diff --git a/pycsghub/cmd/repo_types.py b/pycsghub/cmd/repo_types.py index a00fcba..5b5b330 100644 --- a/pycsghub/cmd/repo_types.py +++ b/pycsghub/cmd/repo_types.py @@ -1,7 +1,8 @@ -import typer from enum import Enum + from pycsghub.constants import REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE + class RepoType(str, Enum): MODEL = REPO_TYPE_MODEL DATASET = REPO_TYPE_DATASET diff --git a/pycsghub/constants.py b/pycsghub/constants.py index d0fbcd9..ff5edf5 100644 --- a/pycsghub/constants.py +++ b/pycsghub/constants.py @@ -1,4 +1,5 @@ import os + API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 API_FILE_DOWNLOAD_TIMEOUT = 5 API_FILE_DOWNLOAD_RETRY_TIMES = 5 @@ -17,8 +18,19 @@ OPERATION_ACTION_GIT = "git" OPERATION_ACTION = [OPERATION_ACTION_API, OPERATION_ACTION_GIT] + +def _get_token_path(): + """Get the token path based on the operating system.""" + if os.environ.get("CSGHUB_TOKEN_PATH"): + return os.environ.get("CSGHUB_TOKEN_PATH") + + home_dir = os.path.expanduser("~") + token_dir = os.path.join(home_dir, ".csghub") + return os.path.join(token_dir, "token") + + CSGHUB_HOME = os.environ.get('CSGHUB_HOME', '/home') -CSGHUB_TOKEN_PATH = os.environ.get("CSGHUB_TOKEN_PATH", os.path.join(CSGHUB_HOME, "token")) +CSGHUB_TOKEN_PATH = _get_token_path() MODEL_ID_SEPARATOR = '/' DEFAULT_CSG_GROUP = 'OpenCSG' @@ -89,8 +101,7 @@ *.webp filter=lfs diff=lfs merge=lfs -text """ - S3_INTERNAL = os.environ.get("S3_INTERNAL", '') GIT_HIDDEN_DIR = ".git" -GIT_ATTRIBUTES_FILE = ".gitattributes" \ No newline at end of file +GIT_ATTRIBUTES_FILE = ".gitattributes" diff --git a/pycsghub/csghub_api.py b/pycsghub/csghub_api.py index 066f2f0..e65034b 100644 --- 
a/pycsghub/csghub_api.py +++ b/pycsghub/csghub_api.py @@ -1,12 +1,15 @@ +import base64 import logging from typing import Dict -from pycsghub.utils import (build_csg_headers, get_endpoint, model_id_to_group_owner_name) + import requests -import base64 + from pycsghub.constants import GIT_ATTRIBUTES_CONTENT, DEFAULT_REVISION, DEFAULT_LICENCE, REPO_TYPE_SPACE +from pycsghub.utils import (build_csg_headers, get_endpoint, model_id_to_group_owner_name) logger = logging.getLogger(__name__) + class CsgHubApi: ''' csghub API wrapper class @@ -16,13 +19,13 @@ def __init__(self): pass def fetch_upload_modes( - self, - payload: Dict, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + self, + payload: Dict, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Requests repo files upload modes @@ -40,15 +43,15 @@ def fetch_upload_modes( raise ValueError(f"invalid json data for fetch upload modes from {fetch_url} response: {response.text}") def fetch_lfs_batch_info( - self, - payload: Dict, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, - local_file: str, - upload_id: str, + self, + payload: Dict, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, + local_file: str, + upload_id: str, ): """ Requests the LFS batch endpoint to retrieve upload instructions @@ -64,16 +67,17 @@ def fetch_lfs_batch_info( try: return response.json() except ValueError: - raise ValueError(f"invalid json data for fetch LFS {local_file} batch info from {batch_url} response: {response.text}") - + raise ValueError( + f"invalid json data for fetch LFS {local_file} batch info from {batch_url} response: {response.text}") + def create_commit( - self, - payload: Dict, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + self, + payload: Dict, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Creates a commit in the given repo, deleting & uploading files as needed. 
@@ -91,12 +95,12 @@ def create_commit( raise ValueError(f"invalid json data for create files commit on {commit_url} response: {response.text}") def repo_branch_exists( - self, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + self, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Check if repo and branch exists @@ -108,7 +112,7 @@ def repo_branch_exists( action_url = f"{action_endpoint}/api/v1/{repo_type}s/{repo_id}/branches" response = requests.get(action_url, headers=req_headers) logger.debug(f"fetch {repo_type} {repo_id} branches on {action_url} response: {response.text}") - + if response.status_code != 200: return False, False jsonRes = response.json() @@ -119,16 +123,16 @@ def repo_branch_exists( for b in branches: if b["name"] == revision: return True, True - + return True, False - + def create_new_branch( - self, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + self, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Create branch @@ -138,7 +142,7 @@ def create_new_branch( "Content-Type": "application/json" }) action_url = f"{action_endpoint}/api/v1/{repo_type}s/{repo_id}/raw/.gitattributes" - + GIT_ATTRIBUTES_CONTENT_BASE64 = base64.b64encode(GIT_ATTRIBUTES_CONTENT.encode()).decode() data = { @@ -146,27 +150,29 @@ def create_new_branch( "new_branch": revision, "content": GIT_ATTRIBUTES_CONTENT_BASE64 } - + response = requests.post(action_url, json=data, headers=req_headers) if response.status_code != 200: - logger.error(f"create new branch {revision} for {repo_type} {repo_id} on {action_endpoint} response: {response.text}") + logger.error( + f"create new branch {revision} for {repo_type} {repo_id} on {action_endpoint} response: {response.text}") response.raise_for_status() try: return response.json() except ValueError: - raise ValueError(f"invalid json data for create new branch {revision} for {repo_type} {repo_id} on {action_url} response: {response.text}") + raise ValueError( + f"invalid json data for create new branch {revision} for {repo_type} {repo_id} on {action_url} response: {response.text}") def create_new_repo( - self, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + self, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Create new repo - """ + """ action_endpoint = get_endpoint(endpoint=endpoint) req_headers = build_csg_headers(token=token, headers={ "Content-Type": "application/json" @@ -181,7 +187,7 @@ def create_new_repo( "private": True, "license": DEFAULT_LICENCE, } - + if repo_type == REPO_TYPE_SPACE: resource_resp = self.get_space_resources(endpoint=endpoint) resources = resource_resp["data"] @@ -190,7 +196,7 @@ def create_new_repo( data["resource_id"] = resource_id else: raise ValueError(f"no any space resource found for create {repo_type} {repo_id}") - + response = requests.post(action_url, json=data, headers=req_headers) if response.status_code != 200: logger.error(f"create new {repo_type} {repo_id} on {action_endpoint} response: {response.text}") @@ -198,11 +204,12 @@ def create_new_repo( try: return response.json() except ValueError: - raise ValueError(f"invalid json data for create new {repo_type} {repo_id} on {action_url} response: {response.text}") + raise ValueError( + f"invalid json data for create new {repo_type} {repo_id} on {action_url} response: {response.text}") def get_space_resources( - self, - endpoint: str, + self, + endpoint: 
str, ): """ Get space resources diff --git a/pycsghub/errors.py b/pycsghub/errors.py index 42917c0..1bad87a 100644 --- a/pycsghub/errors.py +++ b/pycsghub/errors.py @@ -1,4 +1,3 @@ - class NotSupportError(Exception): pass @@ -32,4 +31,4 @@ class FileIntegrityError(Exception): class FileDownloadError(Exception): - pass \ No newline at end of file + pass diff --git a/pycsghub/file_download.py b/pycsghub/file_download.py index 15cc4a1..02f0646 100644 --- a/pycsghub/file_download.py +++ b/pycsghub/file_download.py @@ -1,27 +1,39 @@ +import logging +import os import tempfile +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial from http.cookiejar import CookieJar from pathlib import Path -from typing import Optional, Union, List, Dict +from typing import Dict, List, Optional, Union + import requests from huggingface_hub.utils import filter_repo_objects -from requests.adapters import Retry from tqdm import tqdm +from urllib3.util.retry import Retry + from pycsghub import utils from pycsghub.cache import ModelFileSystemCache -from pycsghub.utils import (build_csg_headers, - get_cache_dir, - model_id_to_group_owner_name, - pack_repo_file_info, - get_file_download_url, - get_endpoint) -from pycsghub.constants import (API_FILE_DOWNLOAD_RETRY_TIMES, - API_FILE_DOWNLOAD_TIMEOUT, - API_FILE_DOWNLOAD_CHUNK_SIZE, - DEFAULT_REVISION) -from pycsghub.errors import FileDownloadError -import os +from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES +from pycsghub.constants import REPO_TYPE_MODEL from pycsghub.errors import InvalidParameter +from pycsghub.utils import (get_cache_dir, + pack_repo_file_info, + get_endpoint, + build_csg_headers) +from pycsghub.utils import (get_file_download_url, + model_id_to_group_owner_name, + get_model_temp_dir) + +# API constants +API_FILE_DOWNLOAD_RETRY_TIMES = 3 +API_FILE_DOWNLOAD_TIMEOUT = 30 + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) def try_to_load_from_cache(): @@ -40,6 +52,171 @@ def get_csg_hub_url(): pass +class MultiThreadDownloader: + """Multi-threaded downloader""" + + def __init__(self, max_workers=4, chunk_size=8192, retry_times=3, timeout=30): + self.max_workers = max_workers + self.chunk_size = chunk_size + self.retry_times = retry_times + self.timeout = timeout + self.lock = threading.Lock() + + def download_file_with_retry(self, url: str, file_path: str, headers: dict = None, + cookies: CookieJar = None, token: str = None, + file_name: str = None, progress_bar: tqdm = None) -> bool: + """Download a single file with retry mechanism""" + headers = headers or {} + get_headers = build_csg_headers(token=token, headers=headers) + + for attempt in range(self.retry_times + 1): + try: + logger.info( + f"Start downloading file: {file_name or os.path.basename(file_path)} (attempt {attempt + 1}/{self.retry_times + 1})") + + temp_file_path = file_path + '.tmp' + + with open(temp_file_path, 'wb') as f: + response = requests.get( + url, + headers=get_headers, + stream=True, + cookies=cookies, + timeout=self.timeout + ) + response.raise_for_status() + + total_size = int(response.headers.get('content-length', 0)) + + if progress_bar: + progress_bar.total = total_size + progress_bar.set_description(f"Downloading {file_name or os.path.basename(file_path)}") + + downloaded_size = 0 + for chunk in response.iter_content(chunk_size=self.chunk_size): + if chunk: + f.write(chunk) + 
downloaded_size += len(chunk) + if progress_bar: + progress_bar.update(len(chunk)) + + if total_size > 0: + actual_size = os.path.getsize(temp_file_path) + if actual_size != total_size: + logger.warning(f"File size mismatch: expected {total_size}, actual {actual_size}") + if attempt < self.retry_times: + logger.info(f"Retry downloading...") + continue + else: + logger.error(f"File size verification failed, reached maximum retry count") + return False + + os.rename(temp_file_path, file_path) + logger.info(f"File download successful: {file_path}") + return True + + except requests.exceptions.RequestException as e: + logger.error(f"Download failed (attempt {attempt + 1}/{self.retry_times + 1}): {e}") + if attempt < self.retry_times: + wait_time = 2 ** attempt + logger.info(f"Waiting {wait_time} seconds before retrying...") + time.sleep(wait_time) + else: + logger.error(f"File download failed, reached maximum retry count: {file_name or os.path.basename(file_path)}") + return False + except Exception as e: + logger.error(f"Unknown error (attempt {attempt + 1}/{self.retry_times + 1}): {e}") + if attempt < self.retry_times: + wait_time = 2 ** attempt + logger.info(f"Waiting {wait_time} seconds before retrying...") + time.sleep(wait_time) + else: + logger.error(f"File download failed, reached maximum retry count: {file_name or os.path.basename(file_path)}") + return False + + return False + + def download_files_parallel(self, download_tasks: List[Dict], + progress_bar: tqdm = None) -> Dict[str, bool]: + """Download multiple files in parallel""" + results = {} + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_file = {} + for task in download_tasks: + future = executor.submit( + self.download_file_with_retry, + url=task['url'], + file_path=task['file_path'], + headers=task.get('headers'), + cookies=task.get('cookies'), + token=task.get('token'), + file_name=task.get('file_name'), + progress_bar=progress_bar + ) + future_to_file[future] = task.get('file_name', os.path.basename(task['file_path'])) + + # Collect results + for future in as_completed(future_to_file): + file_name = future_to_file[future] + try: + success = future.result() + results[file_name] = success + except Exception as e: + logger.error(f"Download task exception: {file_name} - {e}") + results[file_name] = False + + return results + + def download_files_parallel_with_progress(self, download_tasks: List[Dict], + progress_bar: tqdm = None, + progress_tracker=None, + progress_callback=None) -> Dict[str, bool]: + """Download multiple files in parallel, support progress callback""" + results = {} + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_file = {} + for task in download_tasks: + future = executor.submit( + self.download_file_with_retry, + url=task['url'], + file_path=task['file_path'], + headers=task.get('headers'), + cookies=task.get('cookies'), + token=task.get('token'), + file_name=task.get('file_name'), + progress_bar=progress_bar + ) + future_to_file[future] = task.get('file_name', os.path.basename(task['file_path'])) + + for future in as_completed(future_to_file): + file_name = future_to_file[future] + try: + success = future.result() + results[file_name] = success + + if progress_tracker: + progress_tracker.update_progress(file_name, success) + + if progress_callback: + progress_info = progress_tracker.get_progress_info() + progress_callback(progress_info) + + except Exception as e: + logger.error(f"Download task exception: {file_name} - {e}") + 
results[file_name] = False + + if progress_tracker: + progress_tracker.update_progress(file_name, False) + + if progress_callback: + progress_info = progress_tracker.get_progress_info() + progress_callback(progress_info) + + return results + + def file_download( repo_id: str, *, @@ -56,14 +233,14 @@ def file_download( token: Optional[str] = None, repo_type: Optional[str] = None, source: Optional[str] = None, + max_workers: int = 4, + use_parallel: bool = True, ) -> str: if cache_dir is None: cache_dir = get_cache_dir(repo_type=repo_type) if isinstance(cache_dir, Path): cache_dir = str(cache_dir) - temporary_cache_dir = os.path.join(cache_dir, 'temp') - os.makedirs(temporary_cache_dir, exist_ok=True) - + if local_dir is not None and isinstance(local_dir, Path): local_dir = str(local_dir) @@ -71,7 +248,11 @@ def file_download( raise InvalidParameter('file_name cannot be None, if you want to load single file from repo {}'.format(repo_id)) group_or_owner, name = model_id_to_group_owner_name(repo_id) - name = name.replace('.', '___') + if os.name == 'nt': + name = name.replace('.', '___') + invalid_chars = '<>:"|?*' + for char in invalid_chars: + name = name.replace(char, '_') cache = ModelFileSystemCache(cache_dir, group_or_owner, name, local_dir=local_dir) @@ -85,7 +266,7 @@ def file_download( else: download_endpoint = get_endpoint(endpoint=endpoint) # make headers - # todo need to add cookies? + # todo need to add cookies? repo_info = utils.get_repo_info(repo_id=repo_id, revision=revision, token=token, @@ -105,36 +286,278 @@ def file_download( if file_name not in model_files: raise InvalidParameter('file {} not in repo {}'.format(file_name, repo_id)) - with tempfile.TemporaryDirectory(dir=temporary_cache_dir) as temp_cache_dir: - repo_file_info = pack_repo_file_info(file_name, revision) - if not cache.exists(repo_file_info): - file_name = os.path.basename(repo_file_info['Path']) - # get download url - url = get_file_download_url( - model_id=repo_id, - file_path=file_name, - revision=revision, - endpoint=download_endpoint, - repo_type=repo_type, - source=source) - # todo support parallel download api + model_temp_dir = get_model_temp_dir(cache_dir, f"{group_or_owner}/{name}") + + repo_file_info = pack_repo_file_info(file_name, revision) + if not cache.exists(repo_file_info): + file_name = os.path.basename(repo_file_info['Path']) + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=file_name, + revision=revision, + endpoint=download_endpoint, + repo_type=repo_type, + source=source) + + if use_parallel: + downloader = MultiThreadDownloader(max_workers=max_workers) + download_tasks = [{ + 'url': url, + 'file_path': os.path.join(model_temp_dir, file_name), + 'headers': headers, + 'cookies': cookies, + 'token': token, + 'file_name': file_name + }] + + with tqdm(total=1, desc=f"Downloading file", unit="file") as pbar: + results = downloader.download_files_parallel(download_tasks, pbar) + + if not results.get(file_name, False): + raise Exception(f"File download failed: {file_name}") + else: http_get( url=url, - local_dir=temp_cache_dir, + local_dir=model_temp_dir, file_name=file_name, headers=headers, cookies=cookies, token=token) - # todo using hash to check file integrity - temp_file = os.path.join(temp_cache_dir, file_name) - cache.put_file(repo_file_info, temp_file) - print(f"Saved file to '{temp_file}'") - else: - print(f'File {file_name} already in {cache.get_root_location()}, skip downloading!') + # todo using hash to check file integrity + temp_file = 
os.path.join(model_temp_dir, file_name) + cache.put_file(repo_file_info, temp_file) + print(f"Saved file to '{temp_file}'") + else: + print(f'File {file_name} already in {cache.get_root_location()}, skip downloading!') cache.save_model_version(revision_info={'Revision': revision}) return os.path.join(cache.get_root_location(), file_name) + +def snapshot_download_parallel( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = DEFAULT_REVISION, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + allow_patterns: Optional[str] = None, + ignore_patterns: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + endpoint: Optional[str] = None, + token: Optional[str] = None, + source: Optional[str] = None, + verbose: bool = False, + max_workers: int = 4, + use_parallel: bool = True, +) -> str: + """Download snapshot of the entire repository in parallel""" + if repo_type is None: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") + + # Convert string patterns to lists + if allow_patterns and isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + if ignore_patterns and isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + + if verbose: + print(f"[VERBOSE] Starting parallel download for repo_id: {repo_id}") + print(f"[VERBOSE] repo_type: {repo_type}") + print(f"[VERBOSE] revision: {revision}") + print(f"[VERBOSE] allow_patterns: {allow_patterns}") + print(f"[VERBOSE] ignore_patterns: {ignore_patterns}") + + if cache_dir is None: + cache_dir = get_cache_dir(repo_type=repo_type) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if isinstance(local_dir, Path): + local_dir = str(local_dir) + elif isinstance(local_dir, str): + pass + else: + local_dir = str(Path.cwd() / repo_id) + + os.makedirs(local_dir, exist_ok=True) + if verbose: + print(f"[VERBOSE] Created/verified local_dir: {local_dir}") + + if verbose: + print(f"[VERBOSE] cache_dir: {cache_dir}") + print(f"[VERBOSE] local_dir: {local_dir}") + + group_or_owner, name = model_id_to_group_owner_name(repo_id) + + if verbose: + print(f"[VERBOSE] Parsed repo_id - owner: {group_or_owner}, name: {name}") + + cache = ModelFileSystemCache(cache_dir, group_or_owner, name, local_dir=local_dir) + + if local_files_only: + if len(cache.cached_files) == 0: + raise ValueError( + 'Cannot find the requested files in the cached path and outgoing' + ' traffic has been disabled. To enable model look-ups and downloads' + " online, set 'local_files_only' to False.") + return cache.get_root_location() + else: + download_endpoint = get_endpoint(endpoint=endpoint) + if verbose: + print(f"[VERBOSE] download_endpoint: {download_endpoint}") + + # make headers + # todo need to add cookies? + if verbose: + print(f"[VERBOSE] Getting repository info...") + + repo_info = utils.get_repo_info(repo_id, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=download_endpoint, + source=source) + + assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." + assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list." 
+ + if verbose: + print(f"[VERBOSE] Repository SHA: {repo_info.sha}") + print(f"[VERBOSE] Total files in repository: {len(repo_info.siblings)}") + print(f"[VERBOSE] All files in repository:") + for sibling in repo_info.siblings: + print(f"[VERBOSE] - {sibling.rfilename}") + + repo_files = list( + filter_repo_objects( + items=[f.rfilename for f in repo_info.siblings], + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + ) + + if verbose: + print(f"[VERBOSE] Files after filtering: {len(repo_files)}") + for file in repo_files: + print(f"[VERBOSE] - {file}") + model_temp_dir = get_model_temp_dir(cache_dir, f"{group_or_owner}/{name}") + + if verbose: + print(f"[VERBOSE] model_temp_dir: {model_temp_dir}") + print(f"[VERBOSE] Starting parallel download for {len(repo_files)} files...") + + download_tasks = [] + files_to_download = [] + + for repo_file in repo_files: + if verbose: + print(f"[VERBOSE] Processing file: {repo_file}") + + repo_file_info = pack_repo_file_info(repo_file, revision) + if cache.exists(repo_file_info): + file_name = os.path.basename(repo_file_info['Path']) + print(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!") + if verbose: + print(f"[VERBOSE] File already exists, skipping download") + continue + + if verbose: + print(f"[VERBOSE] File does not exist, preparing download...") + + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file, + repo_type=repo_type, + revision=revision, + endpoint=download_endpoint, + source=source) + + if verbose: + print(f"[VERBOSE] Download URL: {url}") + + # Prepare download tasks + download_tasks.append({ + 'url': url, + 'file_path': os.path.join(model_temp_dir, repo_file), + 'headers': headers, + 'cookies': cookies, + 'token': token, + 'file_name': repo_file + }) + files_to_download.append(repo_file) + + if download_tasks: + if use_parallel: + downloader = MultiThreadDownloader(max_workers=max_workers) + + with tqdm(total=len(download_tasks), desc="Parallel downloading files", unit="file") as pbar: + results = downloader.download_files_parallel(download_tasks, pbar) + + failed_files = [] + for file_name, success in results.items(): + if success: + temp_file = os.path.join(model_temp_dir, file_name) + repo_file_info = pack_repo_file_info(file_name, revision) + savedFile = cache.put_file(repo_file_info, temp_file) + print(f"Saved file to '{savedFile}'") + if verbose: + print(f"[VERBOSE] File successfully saved to: {savedFile}") + else: + failed_files.append(file_name) + logger.error(f"File download failed: {file_name}") + + if failed_files: + logger.error(f"Some files download failed: {failed_files}") + raise Exception(f"Some files download failed, please check network connection or retry") + else: + for repo_file in files_to_download: + if verbose: + print(f"[VERBOSE] Starting HTTP download for {repo_file}...") + + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file, + repo_type=repo_type, + revision=revision, + endpoint=download_endpoint, + source=source) + + http_get( + url=url, + local_dir=model_temp_dir, + file_name=repo_file, + headers=headers, + cookies=cookies, + token=token) + + # todo using hash to check file integrity + temp_file = os.path.join(model_temp_dir, repo_file) + if verbose: + print(f"[VERBOSE] Temp file path: {temp_file}") + + repo_file_info = pack_repo_file_info(repo_file, revision) + savedFile = cache.put_file(repo_file_info, temp_file) + print(f"Saved file to '{savedFile}'") + + if verbose: + print(f"[VERBOSE] File 
successfully saved to: {savedFile}") + + cache.save_model_version(revision_info={'Revision': revision}) + + final_location = os.path.join(cache.get_root_location()) + if verbose: + print(f"[VERBOSE] Download completed. Final location: {final_location}") + + return final_location + + def http_get(*, url: str, local_dir: str, @@ -168,43 +591,42 @@ def http_get(*, r.raise_for_status() accept_ranges = r.headers.get('Accept-Ranges') content_length = r.headers.get('Content-Length') - if accept_ranges == 'bytes': - if downloaded_size == 0: - total_content_length = int(content_length) if content_length is not None else None - else: - if downloaded_size > 0: - temp_file.truncate(0) - downloaded_size = temp_file.tell() - total_content_length = int(content_length) if content_length is not None else None - - progress = tqdm( - unit='B', - unit_scale=True, - unit_divisor=1024, - total=total_content_length, - initial=downloaded_size, - desc="Downloading {}".format(file_name), - ) - for chunk in r.iter_content(chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE): + if content_length: + total_content_length = int(content_length) + if downloaded_size > 0 and accept_ranges != 'bytes': + # server doesn't support range requests, restart download + temp_file.seek(0) + temp_file.truncate() + downloaded_size = 0 + total_content_length = 0 + get_headers.pop('Range', None) + r = requests.get(url, headers=get_headers, stream=True, + cookies=cookies, timeout=API_FILE_DOWNLOAD_TIMEOUT) + r.raise_for_status() + content_length = r.headers.get('Content-Length') + if content_length: + total_content_length = int(content_length) + for chunk in r.iter_content(chunk_size=8192): if chunk: - progress.update(len(chunk)) temp_file.write(chunk) - progress.close() break - except Exception as e: - retry = retry.increment('GET', url, error=e) - retry.sleep() - - downloaded_length = os.path.getsize(temp_file.name) - if total_content_length != downloaded_length: - os.remove(temp_file.name) - msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % ( - file_name, total_content_length, downloaded_length) - raise FileDownloadError(msg) - # fix folder recursive issue - os.makedirs(os.path.dirname(os.path.join(local_dir, file_name)), exist_ok=True) - os.replace(temp_file.name, os.path.join(local_dir, file_name)) - return + except requests.exceptions.RequestException as e: + logger.error(f"Download failed: {e}") + if temp_file.tell() == 0: + raise e + else: + # partial download, continue + continue + temp_file.flush() + temp_file.close() + # move temp file to final location + final_file = os.path.join(local_dir, file_name) + os.rename(temp_file.name, final_file) + logger.debug(f"Downloaded {file_name} to {final_file}") + if total_content_length > 0: + actual_size = os.path.getsize(final_file) + if actual_size != total_content_length: + logger.error(f"Warning: Downloaded file size ({actual_size}) doesn't match expected size ({total_content_length})") if __name__ == '__main__': diff --git a/pycsghub/file_upload.py b/pycsghub/file_upload.py index 5c6ebbf..1d11c7e 100644 --- a/pycsghub/file_upload.py +++ b/pycsghub/file_upload.py @@ -1,9 +1,12 @@ import os -import requests from typing import Optional + +import requests + from pycsghub.constants import (DEFAULT_REVISION) from pycsghub.utils import (build_csg_headers, get_endpoint) + def http_upload_file( repo_id: str, repo_type: Optional[str] = None, @@ -12,20 +15,45 @@ def http_upload_file( revision: Optional[str] = DEFAULT_REVISION, endpoint: 
Optional[str] = None, token: Optional[str] = None, - ): + verbose: bool = False, +): + if verbose: + print(f"[DEBUG] Starting file upload...") + print(f"[DEBUG] File path: {file_path}") + print(f"[DEBUG] Repo ID: {repo_id}") + print(f"[DEBUG] Repo type: {repo_type}") + print(f"[DEBUG] Path in repo: {path_in_repo}") + print(f"[DEBUG] Revision: {revision}") + print(f"[DEBUG] Endpoint: {endpoint}") + if not os.path.exists(file_path): raise ValueError(f"file '{file_path}' does not exist") + destination_path = os.path.join(path_in_repo, os.path.basename(file_path)) if path_in_repo else file_path http_endpoint = endpoint if endpoint is not None else get_endpoint() if not http_endpoint.endswith("/"): http_endpoint += "/" http_url = http_endpoint + "api/v1/" + repo_type + "s/" + repo_id + "/upload_file" + + if verbose: + print(f"[DEBUG] HTTP URL: {http_url}") + print(f"[DEBUG] Destination path: {destination_path}") + post_headers = build_csg_headers(token=token) file_data = {'file': open(file_path, 'rb')} form_data = {'file_path': destination_path, 'branch': revision, 'message': 'upload' + file_path} + + if verbose: + print(f"[DEBUG] Sending POST request...") + response = requests.post(http_url, headers=post_headers, data=form_data, files=file_data) + + if verbose: + print(f"[DEBUG] Response status code: {response.status_code}") + print(f"[DEBUG] Response content: {response.content.decode()}") + if response.status_code == 200: print(f"file '{file_path}' upload successfully.") else: - print(f"fail to upload {file_path} with response code: {response.status_code}, error: {response.content.decode()}") - \ No newline at end of file + print( + f"fail to upload {file_path} with response code: {response.status_code}, error: {response.content.decode()}") diff --git a/pycsghub/repo_reader/__init__.py b/pycsghub/repo_reader/__init__.py index 52c7ac5..a0fb14a 100644 --- a/pycsghub/repo_reader/__init__.py +++ b/pycsghub/repo_reader/__init__.py @@ -1,2 +1,2 @@ +from .dataset.huggingface.load import * from .model.huggingface.model_auto import * -from .dataset.huggingface.load import * \ No newline at end of file diff --git a/pycsghub/repo_reader/dataset/huggingface/load.py b/pycsghub/repo_reader/dataset/huggingface/load.py index 13b254c..17473cd 100644 --- a/pycsghub/repo_reader/dataset/huggingface/load.py +++ b/pycsghub/repo_reader/dataset/huggingface/load.py @@ -1,41 +1,44 @@ from typing import Dict, Mapping, Optional, Sequence, Union + import datasets -from datasets.splits import Split -from datasets.features import Features +from datasets.arrow_dataset import Dataset +from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_config import DownloadConfig from datasets.download.download_manager import DownloadMode +from datasets.features import Features +from datasets.iterable_dataset import IterableDataset +from datasets.splits import Split from datasets.utils.info_utils import VerificationMode from datasets.utils.version import Version -from datasets.iterable_dataset import IterableDataset -from datasets.dataset_dict import DatasetDict, IterableDatasetDict -from datasets.arrow_dataset import Dataset + +from pycsghub.constants import REPO_TYPE_DATASET from pycsghub.snapshot_download import snapshot_download from pycsghub.utils import get_token_to_send -from pycsghub.constants import REPO_TYPE_DATASET + def load_dataset( - path: str, - name: Optional[str] = None, - data_dir: Optional[str] = None, - data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, 
Sequence[str]]]]] = None, - split: Optional[Union[str, Split]] = None, - cache_dir: Optional[str] = None, - features: Optional[Features] = None, - download_config: Optional[DownloadConfig] = None, - download_mode: Optional[Union[DownloadMode, str]] = None, - verification_mode: Optional[Union[VerificationMode, str]] = None, - ignore_verifications="deprecated", - keep_in_memory: Optional[bool] = None, - save_infos: bool = False, - revision: Optional[Union[str, Version]] = None, - token: Optional[Union[bool, str]] = None, - use_auth_token="deprecated", - task="deprecated", - streaming: bool = False, - num_proc: Optional[int] = None, - storage_options: Optional[Dict] = None, - trust_remote_code: bool = None, - **config_kwargs, + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + split: Optional[Union[str, Split]] = None, + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + ignore_verifications="deprecated", + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[str, Version]] = None, + token: Optional[Union[bool, str]] = None, + use_auth_token="deprecated", + task="deprecated", + streaming: bool = False, + num_proc: Optional[int] = None, + storage_options: Optional[Dict] = None, + trust_remote_code: bool = None, + **config_kwargs, ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]: if token is None: try: @@ -66,4 +69,4 @@ def load_dataset( storage_options=storage_options, trust_remote_code=trust_remote_code, **config_kwargs - ) \ No newline at end of file + ) diff --git a/pycsghub/repo_reader/model/huggingface/model_auto.py b/pycsghub/repo_reader/model/huggingface/model_auto.py index 442eaf9..86350ad 100644 --- a/pycsghub/repo_reader/model/huggingface/model_auto.py +++ b/pycsghub/repo_reader/model/huggingface/model_auto.py @@ -1,8 +1,11 @@ +import os +from pathlib import Path + import transformers + from pycsghub.snapshot_download import snapshot_download from pycsghub.utils import get_token_to_send -import os -from pathlib import Path + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, @@ -21,7 +24,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, return model - class_names = [] for i in transformers.__all__: @@ -42,7 +44,3 @@ def from_pretrained(cls, pretrained_model_name_or_path, }) except AttributeError as e: print(e) - - - - diff --git a/pycsghub/repository.py b/pycsghub/repository.py index 6feb6be..a9b3ba3 100644 --- a/pycsghub/repository.py +++ b/pycsghub/repository.py @@ -1,23 +1,27 @@ +import base64 import os -from typing import Optional +import re +import shutil import subprocess -from typing import List, Optional, Union +import tempfile +import traceback from pathlib import Path -import requests -import base64 -import shutil -import re +from typing import List, Optional, Union from urllib.parse import urlparse -from pycsghub.constants import (GIT_ATTRIBUTES_CONTENT, - OPERATION_ACTION_GIT, - REPO_TYPE_DATASET, - REPO_TYPE_SPACE, + +import requests + +from pycsghub.constants import (GIT_ATTRIBUTES_CONTENT, + OPERATION_ACTION_GIT, + REPO_TYPE_DATASET, + REPO_TYPE_SPACE, REPO_TYPE_CODE) -from pycsghub.constants import (GIT_HIDDEN_DIR, GIT_ATTRIBUTES_FILE) +from 
pycsghub.constants import (GIT_HIDDEN_DIR) from pycsghub.utils import (build_csg_headers, model_id_to_group_owner_name, get_endpoint) + def ignore_folders(folder, contents): ignored = [] exclude_list = [GIT_HIDDEN_DIR] @@ -26,28 +30,33 @@ def ignore_folders(folder, contents): ignored.append(item) return ignored + class Repository: def __init__( - self, - repo_id: str, - upload_path: str, - path_in_repo: Optional[str] = "", - branch_name: Optional[str] = "main", - work_dir: Optional[str] = "/tmp/csg", - user_name: Optional[str] = "", - token: Optional[str] = "", - license: Optional[str] = "apache-2.0", - nickname: Optional[str] = "", - description: Optional[str] = "", - repo_type: Optional[str] = None, - endpoint: Optional[str] = None, - auto_create: Optional[bool] = True, - ): + self, + repo_id: str, + upload_path: str, + path_in_repo: Optional[str] = "", + branch_name: Optional[str] = "main", + work_dir: Optional[str] = "/tmp/csg", + user_name: Optional[str] = "", + token: Optional[str] = "", + license: Optional[str] = "apache-2.0", + nickname: Optional[str] = "", + description: Optional[str] = "", + repo_type: Optional[str] = None, + endpoint: Optional[str] = None, + auto_create: Optional[bool] = True, + verbose: bool = False, + ): self.repo_id = repo_id self.upload_path = upload_path self.path_in_repo = path_in_repo self.branch_name = branch_name - self.work_dir = work_dir + if os.name == "nt": + self.work_dir = os.path.join(tempfile.gettempdir(), "csg") + else: + self.work_dir = work_dir self.user_name = user_name self.token = token self.license = license @@ -56,11 +65,12 @@ def __init__( self.repo_type = repo_type self.endpoint = endpoint self.auto_create = auto_create + self.verbose = verbose self.repo_url_prefix = self.get_url_prefix() self.namespace, self.name = model_id_to_group_owner_name(model_id=self.repo_id) self.repo_dir = os.path.join(self.work_dir, self.name) self.user_name = self.user_name if self.user_name else self.namespace - + def get_url_prefix(self): if self.repo_type == REPO_TYPE_DATASET: return "datasets" @@ -70,63 +80,143 @@ def get_url_prefix(self): return "codes" else: return "models" - + def upload(self) -> None: + if self.verbose: + print(f"[DEBUG] Starting upload process...") + print(f"[DEBUG] Upload path: {self.upload_path}") + print(f"[DEBUG] Work dir: {self.work_dir}") + print(f"[DEBUG] Repo ID: {self.repo_id}") + print(f"[DEBUG] Repo type: {self.repo_type}") + if not os.path.exists(self.upload_path): - raise ValueError("upload path does not exist") - + raise ValueError("upload path does not exist") + if not os.path.exists(self.work_dir): os.makedirs(self.work_dir, exist_ok=True) - + if self.auto_create: + if self.verbose: + print(f"[DEBUG] Auto-creating repo and branch...") self.auto_create_repo_and_branch() - - if os.path.exists(self.repo_dir): - shutil.rmtree(self.repo_dir) - + repo_url = self.generate_repo_clone_url() - self.git_clone(branch_name=self.branch_name, repo_url=repo_url) + if self.verbose: + print(f"[DEBUG] Repo URL: {repo_url}") + + if os.path.exists(self.repo_dir): + try: + if self.verbose: + print(f"[DEBUG] Repository exists, pulling latest changes...") + self.git_pull(work_dir=self.repo_dir) + except Exception as e: + if self.verbose: + print(f"[DEBUG] Pull failed, removing and re-cloning: {str(e)}") + print(f"Update repository failed, re-cloning: {str(e)}") + try: + shutil.rmtree(self.repo_dir) + except PermissionError as e: + print(traceback.format_exc()) + raise Exception("permission denied,please run this program with administrator 
privileges") + self.git_clone(branch_name=self.branch_name, repo_url=repo_url) + else: + if self.verbose: + print(f"[DEBUG] Repository doesn't exist, cloning...") + self.git_clone(branch_name=self.branch_name, repo_url=repo_url) + + if self.verbose: + print(f"[DEBUG] Copying files to repository...") git_cmd_workdir = self.copy_repo_files() - + + if self.verbose: + print(f"[DEBUG] Tracking large files...") self.track_large_files(work_dir=git_cmd_workdir) + + if self.verbose: + print(f"[DEBUG] Adding files to git...") self.git_add(work_dir=git_cmd_workdir) + + if self.verbose: + print(f"[DEBUG] Committing changes...") self.git_commit(work_dir=git_cmd_workdir) + number_of_commits = self.commits_to_push(work_dir=git_cmd_workdir) + if self.verbose: + print(f"[DEBUG] Commits to push: {number_of_commits}") + if number_of_commits > 1: + if self.verbose: + print(f"[DEBUG] Pushing changes to remote...") self.git_push(work_dir=git_cmd_workdir) def copy_repo_files(self): - from_path = "" - git_cmd_workdir = "" - + """Copy files to repository directory, optimized version""" from_path = self.upload_path git_cmd_workdir = self.repo_dir - destination_path = git_cmd_workdir - + path_suffix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" path_suffix = re.sub(r'^\./', '', path_suffix) - - destination_path = os.path.join(destination_path, path_suffix) - + + try: + destination_path = os.path.join(git_cmd_workdir, path_suffix) + os.path.normpath(destination_path) + except (OSError, ValueError) as e: + print(f"Path encoding error: {e}") + destination_path = os.path.join(git_cmd_workdir, "upload") + if not os.path.exists(destination_path): os.makedirs(destination_path, exist_ok=True) - - for item in os.listdir(destination_path): - item_path = os.path.join(destination_path, item) - if item != GIT_HIDDEN_DIR and item != GIT_ATTRIBUTES_FILE: - if os.path.isfile(item_path): - os.remove(item_path) - elif os.path.isdir(item_path): - shutil.rmtree(item_path) - + if os.path.isfile(self.upload_path): - shutil.copyfile(self.upload_path, destination_path) + try: + filename = os.path.basename(self.upload_path) + safe_filename = self._get_safe_filename(filename) + destination_file_path = os.path.join(destination_path, safe_filename) + shutil.copyfile(self.upload_path, destination_file_path) + except (OSError, UnicodeError) as e: + print(f"File copy failed: {e}") + destination_file_path = os.path.join(destination_path, "uploaded_file") + shutil.copyfile(self.upload_path, destination_file_path) else: - shutil.copytree(from_path, destination_path, dirs_exist_ok=True, ignore=ignore_folders) + try: + shutil.copytree(from_path, destination_path, dirs_exist_ok=True, ignore=ignore_folders) + except (OSError, UnicodeError) as e: + print(f"Directory copy failed: {e}") + self._copy_files_individually(from_path, destination_path) return git_cmd_workdir + def _get_safe_filename(self, filename): + """Get safe filename, handle encoding issues""" + try: + filename.encode('utf-8').decode('utf-8') + return filename + except UnicodeError: + import hashlib + safe_name = hashlib.md5(filename.encode('utf-8', errors='ignore')).hexdigest() + ext = os.path.splitext(filename)[1] + return f"file_{safe_name}{ext}" + + def _copy_files_individually(self, from_path, destination_path): + """Copy files individually, handle encoding issues""" + if not os.path.exists(from_path): + return + + for item in os.listdir(from_path): + try: + source_item = os.path.join(from_path, item) + dest_item = os.path.join(destination_path, 
self._get_safe_filename(item)) + + if os.path.isfile(source_item): + shutil.copyfile(source_item, dest_item) + elif os.path.isdir(source_item): + os.makedirs(dest_item, exist_ok=True) + self._copy_files_individually(source_item, dest_item) + except (OSError, UnicodeError) as e: + print(f"Skip file {item}: {e}") + continue + def auto_create_repo_and_branch(self): repoExist, branchExist = self.repo_exists() if not repoExist: @@ -135,7 +225,7 @@ def auto_create_repo_and_branch(self): err_msg = f"fail to create new repo for {self.repo_id} with http status code '{response.status_code}' and message '{response.text}'" raise ValueError(err_msg) repoExist, branchExist = self.repo_exists() - + if not branchExist: response = self.create_new_branch() if response.status_code != 200: @@ -154,9 +244,9 @@ def repo_exists(self): response = requests.get(url, headers=headers) if response.status_code != 200: return False, False - + response.raise_for_status() - + jsonRes = response.json() if jsonRes["msg"] != "OK": return True, False @@ -165,13 +255,13 @@ def repo_exists(self): for b in branches: if b["name"] == self.branch_name: return True, True - + return True, False - + def create_new_branch(self): action_endpoint = get_endpoint(endpoint=self.endpoint) url = f"{action_endpoint}/api/v1/{self.repo_url_prefix}/{self.repo_id}/raw/.gitattributes" - + GIT_ATTRIBUTES_CONTENT_BASE64 = base64.b64encode(GIT_ATTRIBUTES_CONTENT.encode()).decode() data = { @@ -179,7 +269,7 @@ def create_new_branch(self): "new_branch": self.branch_name, "content": GIT_ATTRIBUTES_CONTENT_BASE64 } - + headers = build_csg_headers(token=self.token, headers={ "Content-Type": "application/json" }) @@ -202,7 +292,7 @@ def create_new_repo(self): "license": self.license, "description": self.description, } - + headers = build_csg_headers(token=self.token, headers={ "Content-Type": "application/json" }) @@ -220,15 +310,22 @@ def generate_repo_clone_url(self) -> str: return clone_url def git_clone( - self, - branch_name: str, - repo_url: str + self, + branch_name: str, + repo_url: str ) -> subprocess.CompletedProcess: try: env = os.environ.copy() env.update({"GIT_LFS_SKIP_SMUDGE": "1"}) + if os.name == "nt": + try: + self.run_subprocess("git config --global core.quotepath false".split(), folder=self.work_dir, + check=False) + except: + pass # ignore configuration error, continue execution + result = self.run_subprocess( - command=f"git clone -b {branch_name} {repo_url}", + command=f"git clone -b {branch_name} {repo_url}", folder=self.work_dir, check=True, env=env @@ -236,11 +333,11 @@ def git_clone( except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return result - + def git_add( - self, - work_dir: str, - pattern: str = "." + self, + work_dir: str, + pattern: str = "." 
) -> subprocess.CompletedProcess: try: result = self.run_subprocess("git add -v".split() + [pattern], work_dir) @@ -249,9 +346,9 @@ def git_add( return result def git_commit( - self, - work_dir: str, - commit_message: str = "commit files to CSGHub" + self, + work_dir: str, + commit_message: str = "commit files to CSGHub" ) -> subprocess.CompletedProcess: try: result = self.run_subprocess("git commit -v -m".split() + [commit_message], work_dir) @@ -259,12 +356,17 @@ def git_commit( if len(exc.stderr) > 0: raise EnvironmentError(exc.stderr) else: - raise EnvironmentError(exc.stdout) + err_str = exc.stdout + if "nothing to commit, working tree clean" in err_str: + print(err_str) + exit() + else: + raise EnvironmentError(exc.stdout) return result def git_push( - self, - work_dir: str, + self, + work_dir: str, ) -> subprocess.CompletedProcess: try: result = self.run_subprocess("git push".split(), work_dir) @@ -274,7 +376,21 @@ def git_push( else: raise EnvironmentError(exc.stdout) return result - + + def git_pull( + self, + work_dir: str, + ) -> subprocess.CompletedProcess: + """Update repository to the latest version""" + try: + result = self.run_subprocess("git pull".split(), work_dir) + except subprocess.CalledProcessError as exc: + if len(exc.stderr) > 0: + raise EnvironmentError(exc.stderr) + else: + raise EnvironmentError(exc.stdout) + return result + def commits_to_push(self, work_dir: Union[str, Path]) -> int: try: result = self.run_subprocess(f"git cherry -v", work_dir) @@ -290,12 +406,13 @@ def track_large_files(self, work_dir: str, pattern: str = ".") -> List[str]: for filename in self.list_files_to_be_staged(work_dir=work_dir, pattern=pattern): if filename in deleted_files: continue - + path_to_file = os.path.join(os.getcwd(), work_dir, filename) size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) - if size_in_mb >= 1 and not self.is_tracked_with_lfs(filename=path_to_file) and not self.is_git_ignored(filename=path_to_file): - self.lfs_track(work_dir=work_dir,patterns=filename) + if size_in_mb >= 1 and not self.is_tracked_with_lfs(filename=path_to_file) and not self.is_git_ignored( + filename=path_to_file): + self.lfs_track(work_dir=work_dir, patterns=filename) files_to_be_tracked_with_lfs.append(filename) self.lfs_untrack(work_dir=work_dir, patterns=deleted_files) @@ -304,6 +421,14 @@ def track_large_files(self, work_dir: str, pattern: str = ".") -> List[str]: def list_files_to_be_staged(self, work_dir: str, pattern: str = ".") -> List[str]: try: + try: + self.run_subprocess("git config --global core.quotepath false".split(), work_dir) + except subprocess.CalledProcessError: + try: + self.run_subprocess("git config core.quotepath false".split(), work_dir) + except subprocess.CalledProcessError: + pass + p = self.run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], work_dir) if len(p.stdout.strip()): files = p.stdout.strip().split("\n") @@ -313,7 +438,7 @@ def list_files_to_be_staged(self, work_dir: str, pattern: str = ".") -> List[str raise EnvironmentError(exc.stderr) return files - + def list_deleted_files(self, work_dir: str) -> List[str]: try: git_status = self.run_subprocess("git status -s", work_dir).stdout.strip() @@ -340,7 +465,6 @@ def lfs_track(self, work_dir: str, patterns: Union[str, List[str]], filename: bo except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) - def lfs_untrack(self, work_dir: str, patterns: Union[str, List[str]]): if isinstance(patterns, str): patterns = [patterns] @@ -385,12 +509,12 @@ def 
is_git_ignored(self, filename: Union[str, Path]) -> bool: return is_ignored def run_subprocess( - self, - command: Union[str, List[str]], - folder: Optional[Union[str, Path]] = None, - check: bool = True, - **kwargs, - ) -> subprocess.CompletedProcess: + self, + command: Union[str, List[str]], + folder: Optional[Union[str, Path]] = None, + check: bool = True, + **kwargs, + ) -> subprocess.CompletedProcess: if isinstance(command, str): command = command.split() @@ -406,4 +530,4 @@ def run_subprocess( errors="replace", cwd=folder or os.getcwd(), **kwargs, - ) + ) diff --git a/pycsghub/snapshot_download.py b/pycsghub/snapshot_download.py index 2e3867a..9efa040 100644 --- a/pycsghub/snapshot_download.py +++ b/pycsghub/snapshot_download.py @@ -1,19 +1,71 @@ +import logging import os -import tempfile from http.cookiejar import CookieJar from pathlib import Path -from typing import Dict, List, Optional, Union -from pycsghub.utils import (get_file_download_url, - model_id_to_group_owner_name) +from typing import Dict, Optional, Union, Callable, List +import threading + +from huggingface_hub.utils import filter_repo_objects +from tqdm import tqdm + +from pycsghub import utils from pycsghub.cache import ModelFileSystemCache +from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES +from pycsghub.constants import REPO_TYPE_MODEL +from pycsghub.file_download import http_get, MultiThreadDownloader from pycsghub.utils import (get_cache_dir, pack_repo_file_info, get_endpoint) -from huggingface_hub.utils import filter_repo_objects -from pycsghub.file_download import http_get -from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES -from pycsghub import utils -from pycsghub.constants import REPO_TYPE_MODEL +from pycsghub.utils import (get_file_download_url, + model_id_to_group_owner_name, + get_model_temp_dir) + +logger = logging.getLogger(__name__) + + +class DownloadProgressTracker: + """Download progress tracker""" + + def __init__(self, total_files: int): + self.total_files = total_files + self.current_downloaded = 0 + self.success_count = 0 + self.failed_count = 0 + self.successful_files = [] + self.remaining_files = [] + self.lock = threading.Lock() + + def update_progress(self, file_name: str, success: bool): + """Update download progress""" + with self.lock: + self.current_downloaded += 1 + if success: + self.success_count += 1 + self.successful_files.append(file_name) + else: + self.failed_count += 1 + + if file_name in self.remaining_files: + self.remaining_files.remove(file_name) + + def get_progress_info(self) -> Dict: + """Get current progress information""" + with self.lock: + return { + 'total_files': self.total_files, + 'current_downloaded': self.current_downloaded, + 'success_count': self.success_count, + 'failed_count': self.failed_count, + 'successful_files': self.successful_files.copy(), + 'remaining_count': len(self.remaining_files), + 'remaining_files': self.remaining_files.copy() + } + + def set_remaining_files(self, files: List[str]): + """Set remaining file list""" + with self.lock: + self.remaining_files = files.copy() + def snapshot_download( repo_id: str, @@ -24,30 +76,58 @@ def snapshot_download( local_dir: Union[str, Path, None] = None, local_files_only: Optional[bool] = False, cookies: Optional[CookieJar] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[str] = None, + ignore_patterns: Optional[str] = None, headers: Optional[Dict[str, str]] = None, endpoint: 
Optional[str] = None, token: Optional[str] = None, source: Optional[str] = None, + enable_parallel: bool = False, + max_parallel_workers: int = 4, + progress_callback: Optional[Callable[[Dict], None]] = None, ) -> str: if repo_type is None: repo_type = REPO_TYPE_MODEL if repo_type not in REPO_TYPES: raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") + + # Convert string patterns to lists + if allow_patterns and isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + if ignore_patterns and isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + + logger.debug(f"Starting download for repo_id: {repo_id}") + logger.debug(f"repo_type: {repo_type}") + logger.debug(f"revision: {revision}") + logger.debug(f"allow_patterns: {allow_patterns}") + logger.debug(f"ignore_patterns: {ignore_patterns}") + logger.debug(f"enable_parallel: {enable_parallel}") + logger.debug(f"max_parallel_workers: {max_parallel_workers}") + if cache_dir is None: cache_dir = get_cache_dir(repo_type=repo_type) if isinstance(cache_dir, Path): cache_dir = str(cache_dir) - temporary_cache_dir = os.path.join(cache_dir, 'temp') - os.makedirs(temporary_cache_dir, exist_ok=True) - - if local_dir is not None and isinstance(local_dir, Path): + + if isinstance(local_dir, Path): local_dir = str(local_dir) + elif isinstance(local_dir, str): + pass + else: + local_dir = str(Path.cwd()) + + os.makedirs(local_dir, exist_ok=True) + + logger.debug(f"created/verified local_dir: {local_dir}") + logger.debug(f"cache_dir: {cache_dir}") + logger.debug(f"local_dir: {local_dir}") group_or_owner, name = model_id_to_group_owner_name(repo_id) # name = name.replace('.', '___') + logger.debug(f"parsed repo_id - owner: {group_or_owner}, name: {name}") + cache = ModelFileSystemCache(cache_dir, group_or_owner, name, local_dir=local_dir) if local_files_only: @@ -59,8 +139,8 @@ def snapshot_download( return cache.get_root_location() else: download_endpoint = get_endpoint(endpoint=endpoint) - # make headers - # todo need to add cookies? + logger.debug(f"download_endpoint: {download_endpoint}") + repo_info = utils.get_repo_info(repo_id, repo_type=repo_type, revision=revision, @@ -70,6 +150,10 @@ def snapshot_download( assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
+ + logger.debug(f"repository SHA: {repo_info.sha}") + logger.debug(f"total files in repository: {len(repo_info.siblings)}") + repo_files = list( filter_repo_objects( items=[f.rfilename for f in repo_info.siblings], @@ -78,35 +162,176 @@ def snapshot_download( ) ) - with tempfile.TemporaryDirectory(dir=temporary_cache_dir) as temp_cache_dir: - for repo_file in repo_files: - repo_file_info = pack_repo_file_info(repo_file, revision) - if cache.exists(repo_file_info): - file_name = os.path.basename(repo_file_info['Path']) - print(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!") - continue - - # get download url - url = get_file_download_url( - model_id=repo_id, - file_path=repo_file, - repo_type=repo_type, - revision=revision, - endpoint=download_endpoint, - source=source) - # todo support parallel download api - http_get( - url=url, - local_dir=temp_cache_dir, - file_name=repo_file, - headers=headers, - cookies=cookies, - token=token) - - # todo using hash to check file integrity - temp_file = os.path.join(temp_cache_dir, repo_file) - savedFile = cache.put_file(repo_file_info, temp_file) - print(f"Saved file to '{savedFile}'") - + model_temp_dir = get_model_temp_dir(cache_dir, f"{group_or_owner}/{name}") + + if enable_parallel: + snapshot_download_with_multi_thread( + repo_id=repo_id, + repo_type=repo_type, + repo_files=repo_files, + revision=revision, + cache=cache, + download_endpoint=download_endpoint, + source=source, + headers=headers, + cookies=cookies, + token=token, + model_temp_dir=model_temp_dir, + max_parallel_workers=max_parallel_workers, + progress_callback=progress_callback, + ) + else: + snapshot_download_with_single_thread( + repo_id=repo_id, + repo_type=repo_type, + repo_files=repo_files, + revision=revision, + cache=cache, + download_endpoint=download_endpoint, + source=source, + headers=headers, + cookies=cookies, + token=token, + model_temp_dir=model_temp_dir, + progress_callback=progress_callback, + ) + cache.save_model_version(revision_info={'Revision': revision}) - return os.path.join(cache.get_root_location()) + + final_location = os.path.join(cache.get_root_location()) + logger.debug(f"download completed. 
Final location: {final_location}") + return final_location + +def snapshot_download_with_single_thread( + repo_id: str, + repo_type: str, + repo_files: list, + revision: str, + cache: ModelFileSystemCache, + download_endpoint: str, + source: str, + headers: Dict[str, str], + cookies: CookieJar, + token: str, + model_temp_dir: str, + progress_callback: Optional[Callable[[Dict], None]] +): + files_to_download = [] + for repo_file in repo_files: + repo_file_info = pack_repo_file_info(repo_file, revision) + if not cache.exists(repo_file_info): + files_to_download.append(repo_file) + + progress_tracker = DownloadProgressTracker(len(files_to_download)) + progress_tracker.set_remaining_files(files_to_download) + + for repo_file in repo_files: + repo_file_info = pack_repo_file_info(repo_file, revision) + if cache.exists(repo_file_info): + file_name = os.path.basename(repo_file_info['Path']) + logger.info(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!") + continue + + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file, + repo_type=repo_type, + revision=revision, + endpoint=download_endpoint, + source=source) + + try: + http_get( + url=url, + local_dir=model_temp_dir, + file_name=repo_file, + headers=headers, + cookies=cookies, + token=token) + + # todo using hash to check file integrity + temp_file = os.path.join(model_temp_dir, repo_file) + savedFile = cache.put_file(repo_file_info, temp_file) + logger.info(f"Saved file to '{savedFile}'") + + progress_tracker.update_progress(repo_file, True) + except Exception as e: + logger.error(f"File download failed: {repo_file} - {e}") + progress_tracker.update_progress(repo_file, False) + + if progress_callback: + progress_info = progress_tracker.get_progress_info() + progress_callback(progress_info) + + +def snapshot_download_with_multi_thread( + repo_id: str, + repo_type: str, + repo_files: list, + revision: str, + cache: ModelFileSystemCache, + download_endpoint: str, + source: str, + headers: Dict[str, str], + cookies: CookieJar, + token: str, + model_temp_dir: str, + max_parallel_workers: int, + progress_callback: Optional[Callable[[Dict], None]], +): + download_tasks = [] + files_to_download = [] + + for repo_file in repo_files: + repo_file_info = pack_repo_file_info(repo_file, revision) + if cache.exists(repo_file_info): + file_name = os.path.basename(repo_file_info['Path']) + logger.info(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!") + continue + + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file, + repo_type=repo_type, + revision=revision, + endpoint=download_endpoint, + source=source) + + download_tasks.append({ + 'url': url, + 'file_path': os.path.join(model_temp_dir, repo_file), + 'headers': headers, + 'cookies': cookies, + 'token': token, + 'file_name': repo_file + }) + files_to_download.append(repo_file) + + if download_tasks: + logger.info(f"Start parallel downloading {len(download_tasks)} files, using {max_parallel_workers} threads") + + progress_tracker = DownloadProgressTracker(len(download_tasks)) + progress_tracker.set_remaining_files(files_to_download) + + downloader = MultiThreadDownloader(max_workers=max_parallel_workers) + + with tqdm(total=len(download_tasks), desc="Parallel downloading files", unit="file") as pbar: + results = downloader.download_files_parallel_with_progress( + download_tasks, pbar, progress_tracker, progress_callback) + + failed_files = [] + for file_name, success in 
results.items(): + if success: + temp_file = os.path.join(model_temp_dir, file_name) + repo_file_info = pack_repo_file_info(file_name, revision) + savedFile = cache.put_file(repo_file_info, temp_file) + logger.info(f"Saved file to '{savedFile}'") + else: + failed_files.append(file_name) + logger.error(f"File download failed: {file_name}") + + if failed_files: + logger.error(f"Some files download failed: {failed_files}") + raise Exception(f"Some files download failed, please check network connection or retry") diff --git a/pycsghub/test/repo_reader/reader_test.py b/pycsghub/test/repo_reader/reader_test.py index 6edf10a..6c2c5b4 100644 --- a/pycsghub/test/repo_reader/reader_test.py +++ b/pycsghub/test/repo_reader/reader_test.py @@ -1,6 +1,7 @@ import unittest + from pycsghub.repo_reader import AutoModelForCausalLM -from pathlib import Path + class MyTestCase(unittest.TestCase): def test_something(self): diff --git a/pycsghub/test/snapshot_download_test.py b/pycsghub/test/snapshot_download_test.py index 51971a6..9b0fb58 100644 --- a/pycsghub/test/snapshot_download_test.py +++ b/pycsghub/test/snapshot_download_test.py @@ -1,7 +1,8 @@ import unittest -from pycsghub.snapshot_download import snapshot_download -from pycsghub.file_download import file_download + from pycsghub.errors import InvalidParameter +from pycsghub.file_download import file_download +from pycsghub.snapshot_download import snapshot_download class MyTestCase(unittest.TestCase): diff --git a/pycsghub/test/utils_test.py b/pycsghub/test/utils_test.py index d4ca215..763d210 100644 --- a/pycsghub/test/utils_test.py +++ b/pycsghub/test/utils_test.py @@ -1,10 +1,13 @@ import unittest + from pycsghub.utils import model_info + class MyTestCase(unittest.TestCase): token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_id = 'wayne0019/lwfmodel' + def test_something(self): self.assertEqual(True, False) # add assertion here @@ -16,7 +19,5 @@ def test_model_info(self): print(fetched_model_info.siblings) - - if __name__ == '__main__': unittest.main() diff --git a/pycsghub/upload_large_folder/consts.py b/pycsghub/upload_large_folder/consts.py index fc41f5f..ecbb439 100644 --- a/pycsghub/upload_large_folder/consts.py +++ b/pycsghub/upload_large_folder/consts.py @@ -1,5 +1,3 @@ -from typing import Literal - REPO_REGULAR_TYPE = "regular" REPO_LFS_TYPE = "lfs" @@ -16,7 +14,7 @@ ".git/*", "*/.git", "**/.git/**", - ".cache/csghub", + ".cache/csghub", ".cache/csghub/*", "*/.cache/csghub", "**/.cache/csghub/**", diff --git a/pycsghub/upload_large_folder/fixes.py b/pycsghub/upload_large_folder/fixes.py index 3d69261..b61a935 100644 --- a/pycsghub/upload_large_folder/fixes.py +++ b/pycsghub/upload_large_folder/fixes.py @@ -6,6 +6,7 @@ except ImportError: from json import JSONDecodeError # type: ignore # noqa: F401 import contextlib +import logging import os import shutil import stat @@ -18,18 +19,18 @@ from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout from . 
import consts -import logging logger = logging.getLogger(__name__) yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore + @contextlib.contextmanager def SoftTemporaryDirectory( - suffix: Optional[str] = None, - prefix: Optional[str] = None, - dir: Optional[Union[Path, str]] = None, - **kwargs, + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: Optional[Union[Path, str]] = None, + **kwargs, ) -> Generator[Path, None, None]: tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs) yield Path(tmpdir.name).resolve() diff --git a/pycsghub/upload_large_folder/hashlib.py b/pycsghub/upload_large_folder/hashlib.py index 90e11e5..8f8f714 100644 --- a/pycsghub/upload_large_folder/hashlib.py +++ b/pycsghub/upload_large_folder/hashlib.py @@ -1,7 +1,8 @@ import functools -import hashlib import sys +import hashlib + _kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {} md5 = functools.partial(hashlib.md5, **_kwargs) sha1 = functools.partial(hashlib.sha1, **_kwargs) diff --git a/pycsghub/upload_large_folder/jobs.py b/pycsghub/upload_large_folder/jobs.py index 1c3fbd8..1c0cc81 100644 --- a/pycsghub/upload_large_folder/jobs.py +++ b/pycsghub/upload_large_folder/jobs.py @@ -1,20 +1,22 @@ import logging -import time import queue -from typing import Any, Dict, List, Optional, Tuple, TypeVar -from .status import LargeUploadStatus, WorkerJob, JOB_ITEM_T +import time +from typing import List, Optional, Tuple + from .consts import WAITING_TIME_IF_NO_TASKS, MAX_NB_LFS_FILES_PER_COMMIT, MAX_NB_REGULAR_FILES_PER_COMMIT +from .status import LargeUploadStatus, WorkerJob, JOB_ITEM_T logger = logging.getLogger(__name__) + def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]: with status.lock: # Commit if more than 5 minutes since last commit attempt (and at least 1 file) if ( - status.nb_workers_commit == 0 - and status.queue_commit.qsize() > 0 - and status.last_commit_attempt is not None - and time.time() - status.last_commit_attempt > 5 * 60 + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 5 * 60 ): status.nb_workers_commit += 1 logger.debug("job: commit (more than 5 minutes since last commit attempt)") @@ -43,7 +45,7 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, status.nb_workers_preupload_lfs += 1 logger.debug("job: preupload LFS (no other worker preuploading LFS)") return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs)) - + # Compute sha256 if at least 1 file and no worker is computing sha256 elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0: status.nb_workers_sha256 += 1 @@ -58,7 +60,7 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, # Preupload LFS file if at least 1 file elif status.queue_preupload_lfs.qsize() > 0 and ( - status.nb_workers_preupload_lfs == 0 + status.nb_workers_preupload_lfs == 0 ): status.nb_workers_preupload_lfs += 1 logger.debug("job: preupload LFS") @@ -78,10 +80,10 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, # Commit if at least 1 file and 1 min since last commit attempt elif ( - status.nb_workers_commit == 0 - and status.queue_commit.qsize() > 0 - and status.last_commit_attempt is not None - and time.time() - status.last_commit_attempt > 1 * 60 + status.nb_workers_commit == 0 
+ and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 1 * 60 ): status.nb_workers_commit += 1 logger.debug("job: commit (1 min since last commit attempt)") @@ -90,14 +92,14 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, # Commit if at least 1 file all other queues are empty and all workers are waiting # e.g. when it's the last commit elif ( - status.nb_workers_commit == 0 - and status.queue_commit.qsize() > 0 - and status.queue_sha256.qsize() == 0 - and status.queue_get_upload_mode.qsize() == 0 - and status.queue_preupload_lfs.qsize() == 0 - and status.nb_workers_sha256 == 0 - and status.nb_workers_get_upload_mode == 0 - and status.nb_workers_preupload_lfs == 0 + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.queue_sha256.qsize() == 0 + and status.queue_get_upload_mode.qsize() == 0 + and status.queue_preupload_lfs.qsize() == 0 + and status.nb_workers_sha256 == 0 + and status.nb_workers_get_upload_mode == 0 + and status.nb_workers_preupload_lfs == 0 ): status.nb_workers_commit += 1 logger.debug("job: commit") @@ -114,6 +116,7 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, logger.debug(f"no task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)") return (WorkerJob.WAIT, []) + def _get_items_to_commit(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: """Special case for commit job: the number of items to commit depends on the type of files.""" # Can take at most 50 regular files and/or 100 LFS files in a single commit @@ -137,8 +140,10 @@ def _get_items_to_commit(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: else: nb_regular += 1 + def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: return [queue.get()] + def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]: return [queue.get() for _ in range(min(queue.qsize(), n))] diff --git a/pycsghub/upload_large_folder/local_folder.py b/pycsghub/upload_large_folder/local_folder.py index de2e891..2f7b836 100644 --- a/pycsghub/upload_large_folder/local_folder.py +++ b/pycsghub/upload_large_folder/local_folder.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Optional, Dict + from .fixes import WeakFileLock logger = logging.getLogger(__name__) @@ -22,13 +23,15 @@ key_lfs_uploaded_ids = "lfs_uploaded_ids" key_remote_oid = "remote_oid" + @dataclass(frozen=True) class LocalUploadFilePaths: path_in_repo: str file_path: Path lock_path: Path metadata_path: Path - + + @dataclass class LocalUploadFileMetadata: """Metadata for a file that is being uploaded to the hub.""" @@ -41,20 +44,20 @@ class LocalUploadFileMetadata: upload_mode: Optional[str] = None # regular | lfs is_uploaded: bool = False is_committed: bool = False - - remote_oid: Optional[str] = None # remote oid - lfs_upload_id: Optional[str] = None # upload id, only used for multipart uploads - lfs_uploaded_ids: Optional[str] = None # uploaded ids + + remote_oid: Optional[str] = None # remote oid + lfs_upload_id: Optional[str] = None # upload id, only used for multipart uploads + lfs_uploaded_ids: Optional[str] = None # uploaded ids # only for runtime - lfs_upload_part_count: Optional[int] = None # total number of parts, only used for multipart uploads - lfs_upload_part_index: Optional[int] = None # index of part, only used for multipart uploads - lfs_upload_part_url: Optional[str] = None # upload url for part + 
lfs_upload_part_count: Optional[int] = None # total number of parts, only used for multipart uploads + lfs_upload_part_index: Optional[int] = None # index of part, only used for multipart uploads + lfs_upload_part_url: Optional[str] = None # upload url for part lfs_upload_chunk_size: Optional[int] = None - lfs_upload_complete_url: Optional[str] = None # for merge multi-part - lfs_upload_verify: Optional[Dict] = None # for verify + lfs_upload_complete_url: Optional[str] = None # for merge multi-part + lfs_upload_verify: Optional[Dict] = None # for verify content_base64: str = "" - + def save(self, paths: LocalUploadFilePaths) -> None: """Save the metadata to disk.""" with WeakFileLock(paths.lock_path): @@ -76,26 +79,27 @@ def save(self, paths: LocalUploadFilePaths) -> None: save_properties(paths.metadata_path, metadata) self.timestamp = new_timestamp + def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata: paths = get_local_upload_paths(local_dir, filename) with WeakFileLock(paths.lock_path): if paths.metadata_path.exists(): try: props = read_properties(paths.metadata_path) - + timestamp = float(props.get(key_timestamp)) - + size = int(props.get(key_size)) - + _should_ignore = props.get(key_should_ignore) should_ignore = None if _should_ignore == "" else _should_ignore.lower() == "true" - + sha256 = props.get(key_sha256) sha1 = props.get(key_sha1) - + _upload_mode = props.get(key_upload_mode) upload_mode = None if _upload_mode == "" else _upload_mode - + is_uploaded = props.get(key_is_uploaded).lower() == "true" is_committed = props.get(key_is_committed).lower() == "true" @@ -103,7 +107,7 @@ def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetad lfs_upload_id = None if _lfs_upload_id == "" else _lfs_upload_id _lfs_uploaded_ids = props.get(key_lfs_uploaded_ids) lfs_uploaded_ids = None if _lfs_uploaded_ids == "" else _lfs_uploaded_ids - + metadata = LocalUploadFileMetadata( timestamp=timestamp, size=size, @@ -125,10 +129,10 @@ def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetad logger.warning(f"could not remove corrupted metadata file {paths.metadata_path}: {e}") if ( - metadata.timestamp is not None - and metadata.is_uploaded # file was uploaded - and not metadata.is_committed # but not committed - and time.time() - metadata.timestamp > 20 * 3600 # and it's been more than 20 hours + metadata.timestamp is not None + and metadata.is_uploaded # file was uploaded + and not metadata.is_committed # but not committed + and time.time() - metadata.timestamp > 20 * 3600 # and it's been more than 20 hours ): metadata.is_uploaded = False @@ -144,18 +148,21 @@ def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetad # empty metadata => we don't know anything expect its size return LocalUploadFileMetadata(size=paths.file_path.stat().st_size) + def read_properties(file_path): config = ConfigParser() with open(file_path, 'r', encoding='utf-8') as f: config.read_string(f.read()) return dict(config['DEFAULT']) + def save_properties(file_path: str, data: dict) -> None: config = ConfigParser() config['DEFAULT'] = data with open(file_path, 'w', encoding='utf-8') as f: config.write(f, space_around_delimiters=False) + def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths: sanitized_filename = os.path.join(*filename.split("/")) if os.name == "nt": diff --git a/pycsghub/upload_large_folder/main.py b/pycsghub/upload_large_folder/main.py index 9cfc296..386356a 100644 --- 
a/pycsghub/upload_large_folder/main.py +++ b/pycsghub/upload_large_folder/main.py @@ -1,46 +1,49 @@ -import os import logging +import os +import signal import threading import time from pathlib import Path -from tqdm.auto import tqdm from typing import Optional, Union, List + +from tqdm.auto import tqdm + from pycsghub.cmd.repo_types import RepoType +from pycsghub.constants import DEFAULT_REVISION from pycsghub.constants import REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE +from pycsghub.csghub_api import CsgHubApi from pycsghub.utils import get_endpoint -from .path import filter_repo_objects +from .consts import DEFAULT_IGNORE_PATTERNS from .local_folder import get_local_upload_paths, read_upload_metadata -from .workers import _worker_job +from .path import filter_repo_objects from .status import LargeUploadStatus -from .consts import DEFAULT_IGNORE_PATTERNS -from pycsghub.csghub_api import CsgHubApi -from pycsghub.constants import DEFAULT_REVISION -import os -import signal +from .workers import _worker_job logger = logging.getLogger(__name__) + def upload_large_folder_internal( - repo_id: str, - local_path: str, - repo_type: RepoType, - revision: str, - endpoint: str, - token: str, - allow_patterns: Optional[Union[List[str], str]], - ignore_patterns: Optional[Union[List[str], str]], - num_workers: Optional[int], - print_report: bool, - print_report_every: int, + repo_id: str, + local_path: str, + repo_type: RepoType, + revision: str, + endpoint: str, + token: str, + allow_patterns: Optional[Union[List[str], str]], + ignore_patterns: Optional[Union[List[str], str]], + num_workers: Optional[int], + print_report: bool, + print_report_every: int, ): try: folder_path = Path(local_path).expanduser().resolve() if not folder_path.is_dir(): raise ValueError(f"provided path '{local_path}' is not a directory") - + if repo_type not in [REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]: - raise ValueError(f"invalid repo type, must be one of {REPO_TYPE_MODEL} or {REPO_TYPE_DATASET} or {REPO_TYPE_SPACE}") - + raise ValueError( + f"invalid repo type, must be one of {REPO_TYPE_MODEL} or {REPO_TYPE_DATASET} or {REPO_TYPE_SPACE}") + api_endpoint = get_endpoint(endpoint=endpoint) if ignore_patterns is None: @@ -48,28 +51,29 @@ def upload_large_folder_internal( elif isinstance(ignore_patterns, str): ignore_patterns = [ignore_patterns] ignore_patterns += DEFAULT_IGNORE_PATTERNS - + if num_workers is None: nb_cores = os.cpu_count() or 1 num_workers = max(nb_cores - 2, 2) - + api = CsgHubApi() - - create_repo(api=api, repo_id=repo_id, repo_type=repo_type, revision=revision, endpoint=api_endpoint, token=token) + + create_repo(api=api, repo_id=repo_id, repo_type=repo_type, revision=revision, endpoint=api_endpoint, + token=token) filtered_paths_list = filter_repo_objects( (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) - + paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list] - + items = [ (paths, read_upload_metadata(folder_path, paths.path_in_repo)) for paths in tqdm(paths_list, desc=f"recovering from cache metadata from {folder_path}/.cache") ] - + logger.info(f"starting {num_workers} worker threads for upload tasks") status = LargeUploadStatus(items) threads = [ @@ -87,7 +91,7 @@ def upload_large_folder_internal( ) for _ in range(num_workers) ] - + for thread in threads: thread.start() @@ -112,14 +116,15 @@ def upload_large_folder_internal( 
except KeyboardInterrupt: print("Terminated by Ctrl+C") os.kill(os.getpid(), signal.SIGTERM) - + + def create_repo( - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): repoExist, branchExist = api.repo_branch_exists( repo_id=repo_id, repo_type=repo_type, revision=revision, diff --git a/pycsghub/upload_large_folder/path.py b/pycsghub/upload_large_folder/path.py index 88f8bd0..ec20d7d 100644 --- a/pycsghub/upload_large_folder/path.py +++ b/pycsghub/upload_large_folder/path.py @@ -4,12 +4,13 @@ T = TypeVar("T") + def filter_repo_objects( - items: Iterable[T], - *, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - key: Optional[Callable[[T], str]] = None, + items: Iterable[T], + *, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + key: Optional[Callable[[T], str]] = None, ) -> Generator[T, None, None]: if isinstance(allow_patterns, str): allow_patterns = [allow_patterns] diff --git a/pycsghub/upload_large_folder/sha.py b/pycsghub/upload_large_folder/sha.py index 5c84e64..3200cd2 100644 --- a/pycsghub/upload_large_folder/sha.py +++ b/pycsghub/upload_large_folder/sha.py @@ -2,10 +2,12 @@ from typing import BinaryIO, Optional, Tuple -from .hashlib import sha1, sha256 from tqdm import tqdm + +from .hashlib import sha1, sha256 from .status import JOB_ITEM_T + def sha_fileobj(fileobj: BinaryIO, item: JOB_ITEM_T, chunk_size: Optional[int] = None) -> Tuple[str, str]: """ Computes the sha256 and sha1 hash of the given file object, by chunks of size `chunk_size`. @@ -26,7 +28,7 @@ def sha_fileobj(fileobj: BinaryIO, item: JOB_ITEM_T, chunk_size: Optional[int] = sha_1 = sha1() header = f'blob {meta.size}\0'.encode('utf-8') sha_1.update(header) - + desc = f"computing sha256 for {paths.file_path}" with tqdm(initial=0, total=meta.size, desc=desc, unit="B", unit_scale=True, dynamic_ncols=True) as pbar: while True: diff --git a/pycsghub/upload_large_folder/slices.py b/pycsghub/upload_large_folder/slices.py index 520f2cf..3e3555c 100644 --- a/pycsghub/upload_large_folder/slices.py +++ b/pycsghub/upload_large_folder/slices.py @@ -1,12 +1,15 @@ import io +import logging +from typing import Dict + import requests from tqdm import tqdm -import logging + from .status import JOB_ITEM_T -from typing import Dict logger = logging.getLogger(__name__) + class UploadTracker(io.BytesIO): def __init__(self, data, progress_bar): super().__init__(data) @@ -17,14 +20,15 @@ def read(self, size=-1): self._progress_bar.update(len(chunk)) return chunk + def slice_upload(item: JOB_ITEM_T): paths, metadata = item upload_desc = f"uploading {paths.file_path}({metadata.lfs_upload_part_index}/{metadata.lfs_upload_part_count})" - + read_chunk_size = metadata.lfs_upload_chunk_size if metadata.lfs_upload_part_index == metadata.lfs_upload_part_count: read_chunk_size = metadata.size - (metadata.lfs_upload_part_count - 1) * metadata.lfs_upload_chunk_size - + chunk_data = None with paths.file_path.open('rb') as f: f.seek((metadata.lfs_upload_part_index - 1) * metadata.lfs_upload_chunk_size) @@ -43,27 +47,31 @@ def slice_upload(item: JOB_ITEM_T): data=upload_data, ) if response.status_code != 200: - logger.error(f"LFS slice {paths.file_path}({metadata.lfs_upload_part_index}/{metadata.lfs_upload_part_count}) upload on {metadata.lfs_upload_part_url} response: 
{response.text}") + logger.error( + f"LFS slice {paths.file_path}({metadata.lfs_upload_part_index}/{metadata.lfs_upload_part_count}) upload on {metadata.lfs_upload_part_url} response: {response.text}") response.raise_for_status() return response.headers + def slices_upload_complete(item: JOB_ITEM_T, uploaded_ids_map: Dict): paths, metadata = item payload = { "oid": metadata.sha256, "uploadId": metadata.lfs_upload_id, "parts": [ - {"partNumber": i+1, "etag": f"{uploaded_ids_map.get(str(i+1))}"} + {"partNumber": i + 1, "etag": f"{uploaded_ids_map.get(str(i + 1))}"} for i in range(metadata.lfs_upload_part_count) ] } response = requests.post(metadata.lfs_upload_complete_url, json=payload) if response.status_code != 200 and (response.status_code < 400 or response.status_code >= 500): - logger.error(f"LFS {paths.file_path} merge all uploaded slices complete on {metadata.lfs_upload_complete_url} response: {response.text}") + logger.error( + f"LFS {paths.file_path} merge all uploaded slices complete on {metadata.lfs_upload_complete_url} response: {response.text}") if response.status_code < 400 or response.status_code >= 500: response.raise_for_status() return response.text + def slices_upload_verify(item: JOB_ITEM_T): paths, metadata = item payload = { @@ -74,6 +82,7 @@ def slices_upload_verify(item: JOB_ITEM_T): verify_header = metadata.lfs_upload_verify.get("header") response = requests.post(verify_url, headers=verify_header, json=payload) if response.status_code != 200: - logger.error(f"LFS {paths.file_path} slices uploaded verify on {verify_url} response: {response.text}, delete file {paths.metadata_path} and retry") + logger.error( + f"LFS {paths.file_path} slices uploaded verify on {verify_url} response: {response.text}, delete file {paths.metadata_path} and retry") response.raise_for_status() return response.text diff --git a/pycsghub/upload_large_folder/status.py b/pycsghub/upload_large_folder/status.py index 3ee9019..f76811b 100644 --- a/pycsghub/upload_large_folder/status.py +++ b/pycsghub/upload_large_folder/status.py @@ -1,20 +1,23 @@ +import base64 import enum -import queue import logging -from threading import Lock +import queue from datetime import datetime -from typing import List, Optional, Tuple -from .local_folder import LocalUploadFileMetadata, LocalUploadFilePaths -from .consts import REPO_LFS_TYPE, REPO_REGULAR_TYPE -import base64 from io import BytesIO +from threading import Lock +from typing import List, Optional, Tuple + from tqdm import tqdm + from .consts import META_FILE_IDENTIFIER, META_FILE_OID_PREFIX +from .consts import REPO_LFS_TYPE, REPO_REGULAR_TYPE +from .local_folder import LocalUploadFileMetadata, LocalUploadFilePaths logger = logging.getLogger(__name__) JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata] + class WorkerJob(enum.Enum): SHA256 = enum.auto() GET_UPLOAD_MODE = enum.auto() @@ -23,17 +26,19 @@ class WorkerJob(enum.Enum): COMMIT = enum.auto() WAIT = enum.auto() # if no tasks are available but we don't want to exit + class ProgressReader: def __init__(self, fileobj, progress_bar): self.fileobj = fileobj self.progress_bar = progress_bar - + def read(self, size=-1): data = self.fileobj.read(size) if data: self.progress_bar.update(len(data)) return data + class LargeUploadStatus: """Contains information, queues and tasks for a large upload process.""" @@ -62,16 +67,16 @@ def __init__(self, items: List[JOB_ITEM_T]): for item in self.items: paths, metadata = item self._lfs_uploaded_ids[paths.file_path] = metadata.lfs_uploaded_ids - + if 
(metadata.upload_mode is not None and metadata.upload_mode == REPO_LFS_TYPE - and metadata.is_uploaded and metadata.is_committed): + and metadata.is_uploaded and metadata.is_committed): num_uploaded_and_commited += 1 elif (metadata.upload_mode is not None and metadata.upload_mode == REPO_REGULAR_TYPE and metadata.is_committed): num_uploaded_and_commited += 1 elif (metadata.sha256 is None or metadata.sha256 == ""): self.queue_sha256.put(item) - elif (metadata.upload_mode is None or metadata.upload_mode == "" + elif (metadata.upload_mode is None or metadata.upload_mode == "" or metadata.remote_oid is None or metadata.remote_oid == ""): self.queue_get_upload_mode.put(item) elif (metadata.upload_mode == REPO_LFS_TYPE and not metadata.is_uploaded): @@ -81,7 +86,7 @@ def __init__(self, items: List[JOB_ITEM_T]): else: num_uploaded_and_commited += 1 logger.debug(f"skipping file {paths.path_in_repo} because they are already uploaded and committed") - + log_msg = "init upload status" if num_uploaded_and_commited > 0: log_msg = f"{log_msg}, found {len(items)} files, {num_uploaded_and_commited} of which are already uploaded and committed" @@ -148,7 +153,7 @@ def current_report(self) -> str: message += f" | queued-slices: {nb_queued_slices}" message += f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})" message += f" | ignored: {ignored_files}\n" - + message += "Workers: " message += f"hashing: {self.nb_workers_sha256} | " message += f"get upload mode: {self.nb_workers_get_upload_mode} | " @@ -181,7 +186,7 @@ def append_lfs_uploaded_slice_id(self, file_path: str, id: int, etag: str): else: new_ids = old_ids self._lfs_uploaded_ids[file_path] = new_ids - + def convert_uploaded_ids_to_map(self, ids: str): id_map = {} if ids is None or ids == "": @@ -190,18 +195,17 @@ def convert_uploaded_ids_to_map(self, ids: str): idx, etag = item.split(':', 2) id_map[idx] = etag return id_map - - + def is_lfs_upload_completed(self, item: JOB_ITEM_T) -> bool: paths, metadata = item with self.lock: uploaded_ids = self._lfs_uploaded_ids.get(paths.file_path) if (uploaded_ids is not None and - metadata.lfs_upload_id is not None and - metadata.lfs_upload_part_count is not None and - len(uploaded_ids.split(",")) == metadata.lfs_upload_part_count): - metadata.lfs_uploaded_ids = uploaded_ids - return True + metadata.lfs_upload_id is not None and + metadata.lfs_upload_part_count is not None and + len(uploaded_ids.split(",")) == metadata.lfs_upload_part_count): + metadata.lfs_uploaded_ids = uploaded_ids + return True return False def compute_file_base64(self, item: JOB_ITEM_T): @@ -210,19 +214,19 @@ def compute_file_base64(self, item: JOB_ITEM_T): self._compute_lfs_file_base64(item=item) elif meta.upload_mode == REPO_REGULAR_TYPE: self._compute_regular_file_base64(item=item) - + def _compute_lfs_file_base64(self, item: JOB_ITEM_T): _, meta = item content = f"{META_FILE_IDENTIFIER}\n{META_FILE_OID_PREFIX}{meta.sha256}\nsize {meta.size}\n" content_bytes = content.encode('utf-8') meta.content_base64 = base64.b64encode(content_bytes).decode('utf-8') - + def _compute_regular_file_base64(self, item: JOB_ITEM_T): paths, meta = item # with open(paths.file_path, 'rb') as f, BytesIO() as b64_buffer: # base64.encode(f, b64_buffer) # meta.content_base64 = b64_buffer.getvalue().decode("utf-8") - + desc = f"converting {paths.file_path} to base64" with tqdm(total=meta.size, desc=desc, unit="B", unit_scale=True, dynamic_ncols=True) as pbar: with open(paths.file_path, 'rb') as f, BytesIO() as b64_buffer: 
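As an aside on the slice bookkeeping in LargeUploadStatus above, here is a minimal, standalone sketch of the round-trip it appears to implement: uploaded parts are recorded as a comma-separated "partIndex:etag" string, and an LFS upload counts as complete once every expected part has a recorded etag. The helper names below (parse_uploaded_ids, is_upload_complete) are illustrative only and are not part of the SDK.

def parse_uploaded_ids(ids):
    # "1:etag-a,2:etag-b" -> {"1": "etag-a", "2": "etag-b"}
    id_map = {}
    if not ids:
        return id_map
    for item in ids.split(","):
        idx, etag = item.split(":", 1)
        id_map[idx] = etag
    return id_map


def is_upload_complete(ids, part_count):
    # Complete once every expected part index has a recorded etag.
    return part_count is not None and len(parse_uploaded_ids(ids or "")) == part_count


assert parse_uploaded_ids("1:aa,2:bb") == {"1": "aa", "2": "bb"}
assert is_upload_complete("1:aa,2:bb", 2)
assert not is_upload_complete("1:aa", 2)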
@@ -231,6 +235,7 @@ def _compute_regular_file_base64(self, item: JOB_ITEM_T): b64_buffer.seek(0) meta.content_base64 = b64_buffer.getvalue().decode("utf-8") + def _format_size(num: int) -> str: """Format size in bytes into a human-readable string. diff --git a/pycsghub/upload_large_folder/utils.py b/pycsghub/upload_large_folder/utils.py index d5f2a1c..d37cdf5 100644 --- a/pycsghub/upload_large_folder/utils.py +++ b/pycsghub/upload_large_folder/utils.py @@ -4,6 +4,7 @@ _hexdig = '0123456789ABCDEFabcdef' _hextobyte = None + def unquote_to_bytes(string): """unquote_to_bytes('abc%20def') -> b'abc def'.""" # Note: strings are encoded as UTF-8. This is only an issue if it contains @@ -34,6 +35,7 @@ def unquote_to_bytes(string): append(item) return b''.join(res) + def unquote(string, encoding='utf-8', errors='replace'): """Replace %xx escapes by their single-character equivalent. The optional encoding and errors parameters specify how to decode percent-encoded diff --git a/pycsghub/upload_large_folder/workers.py b/pycsghub/upload_large_folder/workers.py index ddd6201..55106c8 100644 --- a/pycsghub/upload_large_folder/workers.py +++ b/pycsghub/upload_large_folder/workers.py @@ -1,35 +1,37 @@ +import copy import logging -import traceback import time -import copy +import traceback from typing import Optional, List, Tuple, Dict -from .status import LargeUploadStatus, WorkerJob, JOB_ITEM_T -from .consts import WAITING_TIME_IF_NO_TASKS -from .jobs import _determine_next_job -from .sha import sha_fileobj +from urllib.parse import urlparse, parse_qs + from pycsghub.csghub_api import CsgHubApi -from .utils import unquote from .consts import ( - REPO_REGULAR_TYPE, - REPO_LFS_TYPE, - COMMIT_ACTION_CREATE, + REPO_REGULAR_TYPE, + REPO_LFS_TYPE, + COMMIT_ACTION_CREATE, COMMIT_ACTION_UPDATE, KEY_MSG, MSG_OK, KEY_UPLOADID ) -from urllib.parse import urlparse, parse_qs +from .consts import WAITING_TIME_IF_NO_TASKS +from .jobs import _determine_next_job +from .sha import sha_fileobj from .slices import slice_upload, slices_upload_complete, slices_upload_verify +from .status import LargeUploadStatus, WorkerJob, JOB_ITEM_T +from .utils import unquote logger = logging.getLogger(__name__) + def _worker_job( - status: LargeUploadStatus, - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + status: LargeUploadStatus, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """ Main process for a worker. 
The worker will perform tasks based on the priority list until all files are uploaded @@ -69,9 +71,10 @@ def _worker_job( elif job == WorkerJob.WAIT: _execute_job_waiting(status=status) + def _execute_job_compute_sha256( - items: List[JOB_ITEM_T], - status: LargeUploadStatus, + items: List[JOB_ITEM_T], + status: LargeUploadStatus, ): item = items[0] # single item every time paths, metadata = item @@ -89,15 +92,16 @@ def _execute_job_compute_sha256( with status.lock: status.nb_workers_sha256 -= 1 + def _execute_job_get_upload_model( - items: List[JOB_ITEM_T], - status: LargeUploadStatus, - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + items: List[JOB_ITEM_T], + status: LargeUploadStatus, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): # A maximum of 50 files at a time try: @@ -124,7 +128,7 @@ def _execute_job_get_upload_model( ignore_num += 1 continue if ((metadata.upload_mode == REPO_REGULAR_TYPE and metadata.sha1 == metadata.remote_oid) or - (metadata.upload_mode == REPO_LFS_TYPE and metadata.sha256 == metadata.remote_oid)): + (metadata.upload_mode == REPO_LFS_TYPE and metadata.sha256 == metadata.remote_oid)): metadata.is_uploaded = True metadata.is_committed = True metadata.save(paths) @@ -139,22 +143,23 @@ def _execute_job_get_upload_model( if ignore_num > 0: logger.info(f"ignored {ignore_num} files because should_ignore is true from remote server") - + if same_with_remote_num > 0: logger.info(f"skipped {same_with_remote_num} files because they are identical to the remote server") - + with status.lock: status.nb_workers_get_upload_mode -= 1 + def _execute_job_pre_upload_lfs( - items: List[JOB_ITEM_T], - status: LargeUploadStatus, - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + items: List[JOB_ITEM_T], + status: LargeUploadStatus, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): item = items[0] # single item every time paths, metadata = item @@ -180,13 +185,14 @@ def _execute_job_pre_upload_lfs( status.queue_preupload_lfs.put(item) with status.lock: - status.nb_workers_preupload_lfs -= 1 + status.nb_workers_preupload_lfs -= 1 + def _execute_job_uploading_lfs( - items: List[JOB_ITEM_T], - status: LargeUploadStatus, + items: List[JOB_ITEM_T], + status: LargeUploadStatus, ): - item = items[0] # single item every time + item = items[0] # single item every time paths, metadata = item try: etag = _perform_lfs_slice_upload(item) @@ -194,29 +200,31 @@ def _execute_job_uploading_lfs( metadata.lfs_uploaded_ids = status.get_lfs_uploaded_slice_ids(paths.file_path) metadata.save(paths) except Exception as e: - logger.error(f"failed to preupload LFS {paths.file_path} slice {metadata.lfs_upload_part_index}/{metadata.lfs_upload_part_count}: {e}") + logger.error( + f"failed to preupload LFS {paths.file_path} slice {metadata.lfs_upload_part_index}/{metadata.lfs_upload_part_count}: {e}") traceback.format_exc() status.queue_uploading_lfs.put(item) - + with status.lock: status.nb_workers_uploading_lfs -= 1 + def _execute_job_commit( - items: List[JOB_ITEM_T], - status: LargeUploadStatus, - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + items: List[JOB_ITEM_T], + status: LargeUploadStatus, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): try: for item in items: 
status.compute_file_base64(item=item) - + _commit(items, api=api, endpoint=endpoint, token=token, - repo_id=repo_id, repo_type=repo_type, revision=revision) + repo_id=repo_id, repo_type=repo_type, revision=revision) logger.info(f"committed {len(items)} items") except KeyboardInterrupt: raise @@ -230,32 +238,35 @@ def _execute_job_commit( status.last_commit_attempt = time.time() status.nb_workers_commit -= 1 + def _execute_job_waiting( - status: LargeUploadStatus, + status: LargeUploadStatus, ): time.sleep(WAITING_TIME_IF_NO_TASKS) with status.lock: - status.nb_workers_waiting -= 1 + status.nb_workers_waiting -= 1 + def _compute_sha256(item: JOB_ITEM_T) -> None: """Compute sha256 of a file and save it in metadata.""" paths, metadata = item if metadata.sha256 is None: with paths.file_path.open("rb") as f: - sha256, sha1 = sha_fileobj(fileobj=f, item=item) - metadata.sha256 = sha256 - metadata.sha1 = sha1 - + sha256, sha1 = sha_fileobj(fileobj=f, item=item) + metadata.sha256 = sha256 + metadata.sha1 = sha1 + metadata.save(paths) + def _get_upload_mode( - items: List[JOB_ITEM_T], - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + items: List[JOB_ITEM_T], + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ) -> None: """Get upload mode for each file and update metadata. @@ -273,21 +284,22 @@ def _get_upload_mode( modes_resp = api.fetch_upload_modes( payload=payload, endpoint=endpoint, token=token, repo_id=repo_id, repo_type=repo_type, revision=revision) - + if modes_resp["data"] is None or modes_resp["data"]["files"] is None: raise ValueError("no correct upload modes response found") - + files_modes = modes_resp["data"]["files"] if not isinstance(files_modes, list): raise ValueError("files is not list in upload modes response") - + if not files_modes and len(files_modes) != len(items): - raise ValueError(f"requested {len(items)} files do not match {len(files_modes)} files in fetch upload modes response") + raise ValueError( + f"requested {len(items)} files do not match {len(files_modes)} files in fetch upload modes response") remote_upload_modes: Dict[str, str] = {} remote_should_ignore: Dict[str, bool] = {} remote_file_oids: Dict[str, Optional[str]] = {} - + for file in files_modes: key = file["path"] if file["isDir"]: @@ -295,45 +307,49 @@ def _get_upload_mode( remote_upload_modes[key] = file["uploadMode"] remote_should_ignore[key] = file["shouldIgnore"] remote_file_oids[key] = file["oid"] - + for item in items: paths, metadata = item metadata.upload_mode = remote_upload_modes[paths.path_in_repo] metadata.should_ignore = remote_should_ignore[paths.path_in_repo] - metadata.remote_oid = None if remote_file_oids[paths.path_in_repo] == "" else remote_file_oids[paths.path_in_repo] + metadata.remote_oid = None if remote_file_oids[paths.path_in_repo] == "" else remote_file_oids[ + paths.path_in_repo] metadata.save(paths) + def _preupload_lfs_done( - item: JOB_ITEM_T, - status: LargeUploadStatus, + item: JOB_ITEM_T, + status: LargeUploadStatus, ): paths, metadata = item uploaded_ids = status.get_lfs_uploaded_slice_ids(paths.file_path) uploaded_ids_map = status.convert_uploaded_ids_to_map(uploaded_ids) complete_resp = slices_upload_complete(item=item, uploaded_ids_map=uploaded_ids_map) - logger.debug(f"LFS file {paths.file_path} merge {metadata.lfs_upload_part_count} uploaded slices complete response: {complete_resp}") + logger.debug( + f"LFS file {paths.file_path} merge {metadata.lfs_upload_part_count} 
uploaded slices complete response: {complete_resp}") verify_resp = slices_upload_verify(item=item) logger.debug(f"LFS file {paths.file_path} uploaded verify response: {verify_resp}") metadata.is_uploaded = True metadata.save(paths) logger.debug(f"LFS file {paths.file_path} - all {metadata.lfs_upload_part_count} slices uploaded successfully") + def _preupload_lfs( - item: JOB_ITEM_T, - status: LargeUploadStatus, - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + item: JOB_ITEM_T, + status: LargeUploadStatus, + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ): """Preupload LFS file and update metadata.""" paths, metadata = item - + if metadata.lfs_upload_part_count is not None and metadata.lfs_upload_part_count > 0: return - + payload: Dict = { "operation": "upload", "transfers": ["basic", "multipart"], @@ -348,58 +364,64 @@ def _preupload_lfs( } if revision is not None: payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted' - + batch_resp = api.fetch_lfs_batch_info( payload=payload, endpoint=endpoint, token=token, repo_id=repo_id, repo_type=repo_type, revision=revision, local_file=paths.file_path, upload_id=metadata.lfs_upload_id) - + objects = batch_resp.get("objects", None) if not isinstance(objects, list) or len(objects) < 1: - raise ValueError(f"LFS {paths.file_path} malformed batch response objects is not list from server: {batch_resp}") + raise ValueError( + f"LFS {paths.file_path} malformed batch response objects is not list from server: {batch_resp}") object = objects[0] - + search_key = "actions" if not isinstance(object, dict) or search_key not in object: - raise ValueError(f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") + raise ValueError( + f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") object_actions = object[search_key] - + search_key = "upload" if not isinstance(object_actions, dict) or search_key not in object_actions: - raise ValueError(f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") + raise ValueError( + f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") object_upload = object_actions[search_key] - + search_key = "verify" if not isinstance(object_actions, dict) or search_key not in object_actions: - raise ValueError(f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") + raise ValueError( + f"no slices batch {search_key} info found for response of file {paths.file_path} from server: {object}") object_verify = object_actions[search_key] - + search_key = "header" if not isinstance(object_upload, dict) or search_key not in object_upload: - raise ValueError(f"no slices batch {search_key} found for response of file {paths.file_path} from server: {object}") + raise ValueError( + f"no slices batch {search_key} found for response of file {paths.file_path} from server: {object}") object_upload_header = object_upload[search_key] - + href_key = "href" if not isinstance(object_upload, dict) or search_key not in object_upload: - raise ValueError(f"no slices batch merge address found for response of file {paths.file_path} from server: {object}") - + raise ValueError( + f"no slices batch merge address found for response of file {paths.file_path} from server: {object}") + if not 
isinstance(object_upload_header, Dict): raise ValueError(f"incorrect lfs {paths.file_path} slices upload address from server: {object}") - + chunk_size = object_upload_header.pop("chunk_size") if chunk_size is None: raise ValueError(f"no chunk size found for lfs slices upload of file {paths.file_path}") - + total_count = len(object_upload_header) metadata.lfs_upload_part_count = total_count metadata.lfs_upload_complete_url = object_upload[href_key] metadata.lfs_upload_verify = object_verify - + sorted_keys = sorted(object_upload_header.keys(), key=lambda x: int(x)) parsed_url = urlparse(object_upload_header.get(sorted_keys[0])) query_params = parse_qs(parsed_url.query) metadata.lfs_upload_id = query_params.get(KEY_UPLOADID, [None])[0] - + uploaded_ids = status.get_lfs_uploaded_slice_ids(paths.file_path) existing_ids = status.convert_uploaded_ids_to_map(uploaded_ids) for _, key in enumerate(sorted_keys): @@ -415,6 +437,7 @@ def _preupload_lfs( status.queue_uploading_lfs.put(item_slice) logger.debug(f"get LFS {paths.file_path} slices batch info successfully") + def _perform_lfs_slice_upload(item: JOB_ITEM_T): resp_header = slice_upload(item=item) logger.debug(f"slice upload response header: {resp_header}") @@ -424,17 +447,18 @@ def _perform_lfs_slice_upload(item: JOB_ITEM_T): raise ValueError(f"invalid slice upload response header: {resp_header}, etag: {etag}") return etag.strip('"') + def _commit( - items: List[JOB_ITEM_T], - api: CsgHubApi, - repo_id: str, - repo_type: str, - revision: str, - endpoint: str, - token: str, + items: List[JOB_ITEM_T], + api: CsgHubApi, + repo_id: str, + repo_type: str, + revision: str, + endpoint: str, + token: str, ) -> None: """Commit files to the repo.""" - commit_message="Add files using upload-large-folder tool" + commit_message = "Add files using upload-large-folder tool" payload: Dict = { "message": commit_message, "files": [ @@ -446,14 +470,14 @@ def _commit( for paths, meta in items ] } - + commit_resp = api.create_commit( payload=payload, endpoint=endpoint, token=token, repo_id=repo_id, repo_type=repo_type, revision=revision) - + if commit_resp[KEY_MSG] is None or commit_resp[KEY_MSG] != MSG_OK: raise ValueError(f"create commit response message {commit_resp} is not {MSG_OK}") - + for paths, metadata in items: metadata.is_committed = True metadata.save(paths) diff --git a/pycsghub/utils.py b/pycsghub/utils.py index ff971d4..b8da74a 100644 --- a/pycsghub/utils.py +++ b/pycsghub/utils.py @@ -1,20 +1,28 @@ -from typing import Optional, Union, Dict - -from pathlib import Path +import hashlib import os -from pycsghub.constants import MODEL_ID_SEPARATOR, DEFAULT_CSG_GROUP, DEFAULT_CSGHUB_DOMAIN -from pycsghub.constants import OPERATION_ACTION_API, OPERATION_ACTION_GIT -from pycsghub.constants import REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE -from pycsghub.constants import REPO_SOURCE_CSG, REPO_SOURCE_HF, REPO_SOURCE_MS +import pickle +import shutil +import tempfile +import urllib +from pathlib import Path +from typing import Optional, Union, Dict, Any +from urllib.parse import quote, urlparse +import logging + import requests from huggingface_hub.hf_api import ModelInfo, DatasetInfo, SpaceInfo -import urllib -import hashlib -from pycsghub.errors import FileIntegrityError + from pycsghub._token import _get_token_from_file, _get_token_from_environment -from urllib.parse import quote, urlparse +from pycsghub.constants import MODEL_ID_SEPARATOR, DEFAULT_CSG_GROUP, DEFAULT_CSGHUB_DOMAIN +from pycsghub.constants import OPERATION_ACTION_API, 
OPERATION_ACTION_GIT +from pycsghub.constants import REPO_SOURCE_CSG, REPO_SOURCE_HF, REPO_SOURCE_MS +from pycsghub.constants import REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE from pycsghub.constants import S3_INTERNAL +from pycsghub.errors import FileIntegrityError +import re + +logger = logging.getLogger(__name__) def get_session() -> requests.Session: session = requests.Session() @@ -27,49 +35,239 @@ def get_session() -> requests.Session: return session -def get_token_to_send(token) -> Optional[str]: +def get_token_to_send(token: Optional[str] = None) -> Optional[str]: + """Get token to send + + Priority: + 1. Explicitly provided token parameter + 2. Environment variable CSGHUB_TOKEN + 3. Configuration file ~/.csghub/token + """ if token: return token - else: - return _get_token_from_environment() or _get_token_from_file() + + # Check environment variable + env_token = os.environ.get("CSGHUB_TOKEN") + if env_token: + return env_token + + # Check configuration file + try: + from pycsghub._token import _get_token_from_file + file_token = _get_token_from_file() + if file_token: + return file_token + except Exception: + pass + + return None def _validate_token_to_send(): pass -def build_csg_headers( - *, - token: Optional[Union[bool, str]] = None, - headers: Optional[Dict[str, str]] = None -) -> Dict[str, str]: - # Get auth token to send - token_to_send = get_token_to_send(token) - csg_headers = {} - # Combine headers - if token_to_send is not None: - csg_headers["authorization"] = f"Bearer {token_to_send}" - if headers is not None: - csg_headers.update(headers) - - csg_headers["X-OPENCSG-S3-Internal"] = S3_INTERNAL - csg_headers["Accept-Encoding"] = None - return csg_headers - - -def model_id_to_group_owner_name(model_id: str) -> (str, str): - if MODEL_ID_SEPARATOR in model_id: - group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0] - name = model_id.split(MODEL_ID_SEPARATOR)[1] - else: - group_or_owner = DEFAULT_CSG_GROUP - name = model_id - return group_or_owner, name +def build_csg_headers(token: Optional[str] = None, headers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Build CSG request headers""" + default_headers = { + "User-Agent": "csghub-sdk", + "Accept": "application/json", + } + + if token: + default_headers["Authorization"] = f"Bearer {token}" + + if headers: + default_headers.update(headers) + + return default_headers + + +def model_id_to_group_owner_name(model_id: str): + """Convert repo ID to group and owner name""" + if "/" not in model_id: + raise ValueError(f"Invalid repo_id format: {model_id}") + + parts = model_id.split("/", 1) + if len(parts) != 2: + raise ValueError(f"Invalid repo_id format: {model_id}") + + namespace, name = parts + return namespace, name + + +def validate_cache_directory(cache_dir: str) -> bool: + """Validate cache directory is available + + Args: + cache_dir (str): cache directory path + + Returns: + bool: cache directory is available + """ + try: + # Windows path length check + if os.name == 'nt': + # Windows has 260 character path limit, need to reserve space + if len(os.path.abspath(cache_dir)) > 240: + print(f"Warning: Cache directory path too long for Windows: {cache_dir}") + return False + + # Check if directory exists, if not, create it + os.makedirs(cache_dir, exist_ok=True) + + # Check write permission + test_file = os.path.join(cache_dir, '.test_write') + with open(test_file, 'w') as f: + f.write('test') + os.remove(test_file) + + # Check disk space (at least 100MB) + free_space = shutil.disk_usage(cache_dir).free + 
if free_space < 1024 * 1024 * 100: # 100MB + print(f"Warning: Low disk space in cache directory {cache_dir}") + return False + + return True + except (OSError, IOError) as e: + print(f"Warning: Cache directory validation failed for {cache_dir}: {e}") + return False + + +def cleanup_cache_directory(cache_dir: str) -> bool: + """Clean corrupted cache files + + Args: + cache_dir (str): cache directory path + + Returns: + bool: clean cache directory successfully + """ + try: + # Clean corrupted index files + index_file = os.path.join(cache_dir, '.msc') + if os.path.exists(index_file): + try: + with open(index_file, 'rb') as f: + pickle.load(f) + except (pickle.PickleError, EOFError, IOError): + print(f"Warning: Removing corrupted cache index file: {index_file}") + os.remove(index_file) + + # Clean temporary files + for root, dirs, files in os.walk(cache_dir): + for file in files: + if file.endswith('.tmp') or file.startswith('.test_'): + try: + os.remove(os.path.join(root, file)) + except OSError: + pass + + # Clean empty directories + for root, dirs, files in os.walk(cache_dir, topdown=False): + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + try: + if not os.listdir(dir_path): # directory is empty + os.rmdir(dir_path) + except OSError: + pass + + return True + except Exception as e: + print(f"Warning: Cache cleanup failed for {cache_dir}: {e}") + return False + + +def sanitize_path_for_windows(path: str) -> str: + """Clean path, make it available in Windows + + Args: + path (str): original path + + Returns: + str: cleaned path + """ + if os.name == 'nt': + # Windows forbidden characters + invalid_chars = '<>:"|?*' + for char in invalid_chars: + path = path.replace(char, '_') + + # Handle Windows path length limit + if len(path) > 240: + # Use short path or truncate + try: + import win32api + short_path = win32api.GetShortPathName(path) + if len(short_path) <= 240: + return short_path + except ImportError: + pass + + # Truncate path + parts = path.split(os.sep) + while len(path) > 240 and len(parts) > 1: + parts.pop(1) # Keep root directory + path = os.sep.join(parts) + + return path + + +def get_cache_dir_with_fallback(model_id: Optional[str] = None, repo_type: Optional[str] = None) -> str: + """Get cache directory, if failed, use fallback + + Args: + model_id (str, optional): model ID + repo_type (str, optional): repository type + + Returns: + str: available cache directory path + """ + # Get primary cache directory + primary_cache = get_cache_dir(model_id, repo_type) + + # Clean path in Windows + if os.name == 'nt': + primary_cache = sanitize_path_for_windows(primary_cache) + + # Validate primary cache directory + if validate_cache_directory(primary_cache): + return primary_cache + + # Clean cache directory + cleanup_cache_directory(primary_cache) + + # Validate again + if validate_cache_directory(primary_cache): + return primary_cache + + # Use temporary directory as fallback + fallback_cache = os.path.join(tempfile.gettempdir(), 'csg_cache') + if model_id: + # Handle special characters in model ID in Windows + safe_model_id = model_id.replace('/', '_').replace('\\', '_') + fallback_cache = os.path.join(fallback_cache, safe_model_id) + + try: + os.makedirs(fallback_cache, exist_ok=True) + print(f"Warning: Using fallback cache directory: {fallback_cache}") + return fallback_cache + except OSError as e: + print(f"Error: Cannot create fallback cache directory: {e}") + # Last fallback: use current directory + current_cache = os.path.join(os.getcwd(), '.csg_cache') + if 
model_id: + safe_model_id = model_id.replace('/', '_').replace('\\', '_') + current_cache = os.path.join(current_cache, safe_model_id) + os.makedirs(current_cache, exist_ok=True) + print(f"Warning: Using current directory cache: {current_cache}") + return current_cache def get_cache_dir(model_id: Optional[str] = None, repo_type: Optional[str] = None) -> Union[str, Path]: """cache dir precedence: - function parameter > environment > ~/.cache/csg/hub + function parameter > environment > current directory Args: model_id (str, optional): The model id. @@ -89,22 +287,65 @@ def get_cache_dir(model_id: Optional[str] = None, repo_type: Optional[str] = Non def get_default_cache_dir() -> Path: """ - default base dir: '~/.cache/csg' + default base dir: current directory """ - default_cache_dir = Path.home().joinpath('.cache', 'csg') + default_cache_dir = Path.cwd() return default_cache_dir +def get_model_temp_dir(cache_dir: str, model_id: str) -> str: + # Parse model ID + if '/' in model_id: + owner, name = model_id.split('/', 1) + else: + owner = DEFAULT_CSG_GROUP + name = model_id + + # Handle special characters in Windows + if os.name == 'nt': + # Replace Windows forbidden characters + invalid_chars = '<>:"|?*' + for char in invalid_chars: + owner = owner.replace(char, '_') + name = name.replace(char, '_') + + model_cache_dir = os.path.join(cache_dir, owner, name) + + # Check path length in Windows + if os.name == 'nt' and len(os.path.abspath(model_cache_dir)) > 240: + # Use system temporary directory as fallback + fallback_temp = os.path.join(tempfile.gettempdir(), f'csg_temp_{owner}_{name}') + try: + os.makedirs(fallback_temp, exist_ok=True) + return fallback_temp + except OSError as e: + print(f"Warning: Cannot create fallback temp directory {fallback_temp}: {e}") + # Last fallback: use current directory + current_temp = os.path.join(os.getcwd(), f'.csg_temp_{owner}_{name}') + os.makedirs(current_temp, exist_ok=True) + return current_temp + + try: + os.makedirs(model_cache_dir, exist_ok=True) + return model_cache_dir + except OSError as e: + print(f"Warning: Cannot create model temp directory {model_cache_dir}: {e}") + # Use system temporary directory as fallback + fallback_temp = os.path.join(tempfile.gettempdir(), f'csg_temp_{owner}_{name}') + os.makedirs(fallback_temp, exist_ok=True) + return fallback_temp + + def get_repo_info( - repo_id: str, - *, - revision: Optional[str] = None, - repo_type: Optional[str] = None, - timeout: Optional[float] = None, - files_metadata: bool = False, - token: Union[bool, str, None] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Union[bool, str, None] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: """ Get the info object for a given repo of a given type. 
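The utils helpers above describe a token resolution order (explicit argument, then the CSGHUB_TOKEN environment variable, then the on-disk token file) and a cache-directory fallback chain (validated primary location, then a temp-dir fallback, then the current working directory). A small usage sketch, assuming pycsghub is importable; the token and repo id values are placeholders.

import os

from pycsghub.utils import get_cache_dir_with_fallback, get_token_to_send

# Token resolution: an explicit argument wins over the environment variable.
os.environ["CSGHUB_TOKEN"] = "env-token"          # placeholder value
assert get_token_to_send("explicit-token") == "explicit-token"
assert get_token_to_send() == "env-token"
del os.environ["CSGHUB_TOKEN"]

# Cache directory resolution with Windows path sanitization and disk checks
# handled inside the helper.
cache_dir = get_cache_dir_with_fallback(model_id="namespace/model-name",  # placeholder id
                                        repo_type="model")
print(f"resolved cache directory: {cache_dir}")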
@@ -165,14 +406,14 @@ def get_repo_info( def dataset_info( - repo_id: str, - *, - revision: Optional[str] = None, - timeout: Optional[float] = None, - files_metadata: bool = False, - token: Union[bool, str, None] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Union[bool, str, None] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> DatasetInfo: """ Get info on one specific dataset on opencsg.com. @@ -212,28 +453,31 @@ def dataset_info( """ headers = build_csg_headers(token=token) - path = get_repo_meta_path(repo_type=REPO_TYPE_DATASET, - repo_id=repo_id, - revision=revision, + path = get_repo_meta_path(repo_type=REPO_TYPE_DATASET, + repo_id=repo_id, + revision=revision, endpoint=endpoint, source=source) params = {} if files_metadata: params["blobs"] = True r = requests.get(path, headers=headers, timeout=timeout, params=params) + if r.status_code != 200: + logger.error(f"get {REPO_TYPE_DATASET} meta info from {path} response: {r.text}") r.raise_for_status() data = r.json() return DatasetInfo(**data) + def space_info( - repo_id: str, - *, - revision: Optional[str] = None, - timeout: Optional[float] = None, - files_metadata: bool = False, - token: Union[bool, str, None] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Union[bool, str, None] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> SpaceInfo: """ Get info on one specific space on opencsg.com. @@ -273,29 +517,32 @@ def space_info( """ headers = build_csg_headers(token=token) - path = get_repo_meta_path(repo_type=REPO_TYPE_SPACE, - repo_id=repo_id, - revision=revision, + path = get_repo_meta_path(repo_type=REPO_TYPE_SPACE, + repo_id=repo_id, + revision=revision, endpoint=endpoint, source=source) params = {} if files_metadata: params["blobs"] = True r = requests.get(path, headers=headers, timeout=timeout, params=params) + if r.status_code != 200: + logger.error(f"get {REPO_TYPE_SPACE} meta info from {path} response: {r.text}") r.raise_for_status() data = r.json() return SpaceInfo(**data) + def model_info( - repo_id: str, - *, - revision: Optional[str] = None, - timeout: Optional[float] = None, - securityStatus: Optional[bool] = None, - files_metadata: bool = False, - token: Union[bool, str, None] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + securityStatus: Optional[bool] = None, + files_metadata: bool = False, + token: Union[bool, str, None] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> ModelInfo: """ Note: It is a huggingface method moved here to adjust csghub server response. 
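For the repo metadata helpers in this file (dataset_info, space_info, model_info), a hedged usage sketch of the call-and-raise pattern: the repo id below is a placeholder, and non-200 responses are now logged by the helper before raise_for_status() fires.

import requests

from pycsghub.utils import model_info

try:
    info = model_info(
        "namespace/model-name",            # placeholder repo id
        endpoint="https://hub.opencsg.com",
        token=None,                        # or an access token
    )
    print(info.siblings)
except requests.HTTPError as err:
    # The helper logs the server response body before re-raising.
    print(f"fetching model metadata failed: {err}")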
@@ -338,9 +585,9 @@ def model_info( """ headers = build_csg_headers(token=token) - path = get_repo_meta_path(repo_type=REPO_TYPE_MODEL, - repo_id=repo_id, - revision=revision, + path = get_repo_meta_path(repo_type=REPO_TYPE_MODEL, + repo_id=repo_id, + revision=revision, endpoint=endpoint, source=source) params = {} @@ -349,23 +596,26 @@ def model_info( if files_metadata: params["blobs"] = True r = requests.get(path, headers=headers, timeout=timeout, params=params) + if r.status_code != 200: + logger.error(f"get {REPO_TYPE_MODEL} meta info from {path} response: {r.text}") r.raise_for_status() data = r.json() return ModelInfo(**data) + def get_repo_meta_path( - repo_type: str, - repo_id: str, - revision: Optional[str] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + repo_type: str, + repo_id: str, + revision: Optional[str] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> str: if repo_type != REPO_TYPE_MODEL and repo_type != REPO_TYPE_DATASET and repo_type != REPO_TYPE_SPACE: raise ValueError("repo_type must be one of 'model', 'dataset' or 'space'") - + if source != REPO_SOURCE_CSG and source != REPO_SOURCE_HF and source != REPO_SOURCE_MS and source is not None: raise ValueError("source must be one of 'csg', 'hf' or 'ms'") - + src_prefix = REPO_SOURCE_CSG if source is None else source path = ( f"{endpoint}/{src_prefix}/api/{repo_type}s/{repo_id}/revision/main" @@ -376,12 +626,12 @@ def get_repo_meta_path( def get_file_download_url( - model_id: str, - file_path: str, - revision: str, - repo_type: Optional[str] = None, - endpoint: Optional[str] = None, - source: Optional[str] = None, + model_id: str, + file_path: str, + revision: str, + repo_type: Optional[str] = None, + endpoint: Optional[str] = None, + source: Optional[str] = None, ) -> str: """Format file download url according to `model_id`, `revision` and `file_path`. Args: @@ -395,13 +645,13 @@ def get_file_download_url( file_path = urllib.parse.quote(file_path) revision = urllib.parse.quote(revision) src_prefix = REPO_SOURCE_CSG if source is None else source - + download_url_template = '{endpoint}/{src_prefix}/{model_id}/resolve/{revision}/{file_path}' if repo_type == REPO_TYPE_DATASET: download_url_template = '{endpoint}/{src_prefix}/datasets/{model_id}/resolve/{revision}/{file_path}' elif repo_type == REPO_TYPE_SPACE: download_url_template = '{endpoint}/{src_prefix}/spaces/{model_id}/resolve/{revision}/{file_path}' - + return download_url_template.format( endpoint=endpoint, src_prefix=src_prefix, @@ -419,7 +669,7 @@ def get_endpoint(endpoint: Optional[str] = None, operation: Optional[str] = OPER Returns: str: The formatted endpoint url. 
""" - + env_csghub_domain = os.getenv('CSGHUB_DOMAIN', None) correct_endpoint = None if bool(endpoint) and endpoint != DEFAULT_CSGHUB_DOMAIN: @@ -473,3 +723,35 @@ def pack_repo_file_info(repo_file_path, repo_file_info = {'Path': repo_file_path, 'Revision': revision} return repo_file_info + + +def contains_chinese(text: str) -> bool: + """ + Check if the string contains Chinese characters + + Args: + text: The string to check + + Returns: + bool: If the string contains Chinese characters, return True, otherwise return False + """ + chinese_pattern = re.compile(r'[\u4e00-\u9fff]') + return bool(chinese_pattern.search(text)) + + +def validate_repo_id(repo_id: str) -> None: + """ + Validate if the repository ID contains Chinese characters + + Args: + repo_id: The repository ID + + Raises: + ValueError: If the repository ID contains Chinese characters + """ + if contains_chinese(repo_id): + raise ValueError( + f"❌ Error: Repository ID '{repo_id}' contains Chinese characters. " + f"Repository names cannot contain Chinese characters. " + f"Please use only English letters, numbers, hyphens, and underscores." + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0ed0a1b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "csghub-sdk" +version = "0.7.5" +description = "CSGHub SDK for downloading and uploading models, datasets, and spaces" +readme = "README.md" +license = { text = "Apache-2.0" } +authors = [ + { name = "opencsg", email = "contact@opencsg.com" } +] +keywords = ["ai", "machine-learning", "models", "datasets", "huggingface"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] +requires-python = ">=3.8,<3.14" +dependencies = [ + "typer", + "typing_extensions", + "huggingface_hub>=0.22.2", +] + +[project.optional-dependencies] +train = [ + "torch", + "transformers>=4.33.3", + "datasets>=2.20.0" +] + +[project.scripts] +csghub-cli = "pycsghub.cli:app" + +[project.urls] +Homepage = "https://github.com/opencsg/csghub-sdk" +Documentation = "https://github.com/opencsg/csghub-sdk" +Repository = "https://github.com/opencsg/csghub-sdk" +Issues = "https://github.com/opencsg/csghub-sdk/issues" + +[tool.setuptools.packages.find] +include = ["pycsghub*"] + +[tool.setuptools.package-data] +pycsghub = ["*"] + +[tool.setuptools.dynamic] +version = { attr = "pycsghub.__version__" } \ No newline at end of file diff --git a/setup.py b/setup.py index 10837a5..8c054b6 100644 --- a/setup.py +++ b/setup.py @@ -1,33 +1,10 @@ -from setuptools import setup, find_packages +""" +CSGHub SDK setup configuration +""" +from setuptools import setup -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() +# This file is now mainly used for backward compatibility +# The main configuration has been moved to pyproject.toml -setup( - name='csghub-sdk', - version='0.7.4', - author="opencsg", - author_email="contact@opencsg.com", - 
long_description=long_description, - long_description_content_type="text/markdown", - packages=find_packages(include="pycsghub*"), - include_package_data=True, - entry_points={ - "console_scripts": [ - "csghub-cli=pycsghub.cli:app", - ] - }, - install_requires=[ - "typer", - "typing_extensions", - "huggingface_hub>=0.22.2", - ], - extras_require={ - "train": [ - "torch", - "transformers>=4.33.3", - "datasets>=2.20.0" - ], - }, - python_requires=">=3.8,<3.14", -) +if __name__ == "__main__": + setup()
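The resolve-URL construction in get_file_download_url (see the hunk earlier in this patch) can be verified without any network access. A short sketch follows; the pycsghub.utils import path is an assumption, and the expected URL shapes assume REPO_SOURCE_CSG resolves to "csg".

# Reviewer sketch, not part of the patch: URL shapes produced by get_file_download_url.
from pycsghub.utils import get_file_download_url  # assumed import path

endpoint = "https://hub.opencsg.com"

model_url = get_file_download_url(
    model_id="OpenCSG/csg-wukong-1B",
    file_path="README.md",
    revision="main",
    repo_type="model",
    endpoint=endpoint,
)
dataset_url = get_file_download_url(
    model_id="OpenCSG/csg-wukong-1B",
    file_path="README.md",
    revision="main",
    repo_type="dataset",
    endpoint=endpoint,
)

# With source left as None the CSG prefix is used, so the URLs should come out as
# (assuming REPO_SOURCE_CSG == "csg"):
#   https://hub.opencsg.com/csg/OpenCSG/csg-wukong-1B/resolve/main/README.md
#   https://hub.opencsg.com/csg/datasets/OpenCSG/csg-wukong-1B/resolve/main/README.md
print(model_url)
print(dataset_url)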
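A matching sketch for the repo-id validation helpers added at the end of the utilities hunk; again, the import path is an assumption and the failing repo id is a made-up example containing Chinese characters.

# Reviewer sketch, not part of the patch: behaviour of contains_chinese / validate_repo_id.
from pycsghub.utils import contains_chinese, validate_repo_id  # assumed import path

print(contains_chinese("OpenCSG/csg-wukong-1B"))  # False
print(contains_chinese("OpenCSG/中文模型"))         # True

validate_repo_id("OpenCSG/csg-wukong-1B")          # passes silently

try:
    validate_repo_id("OpenCSG/中文模型")             # hypothetical id with Chinese characters
except ValueError as exc:
    print(exc)                                     # the new, descriptive error message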