Walkthrough: The changes in this pull request introduce a new `CloudStorage` class for cloud storage interactions using Boto3.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (2)
src/core/cloudflare_utils.py (1)
6-56: Add docstrings to improve code documentation

The class `CloudStorage` and its methods lack docstrings. Adding docstrings will enhance code readability and help others understand the purpose and usage of the class and its methods. Consider adding docstrings like this:

```python
class CloudStorage:
    """
    A class to manage cloud storage interactions using Boto3.
    """

    def __init__(self, access_key=None, secret_key=None, endpoint_url=None, bucket=None, session_token=None):
        """
        Initialize the CloudStorage instance with AWS credentials and configurations.

        Args:
            access_key (str): AWS access key ID.
            secret_key (str): AWS secret access key.
            endpoint_url (str): The endpoint URL for the AWS service.
            bucket (str): The name of the S3 bucket.
            session_token (str, optional): AWS session token for temporary credentials.
        """
        # ...

    def initialize(self):
        """
        Validate the credentials and initialize the Boto3 client.

        Returns:
            CloudStorage: The initialized CloudStorage instance.

        Raises:
            ValueError: If required credentials are missing.
        """
        # ...

    def download_files(self, prefix: str, local_dir: str) -> bool:
        """
        Download files from the specified S3 bucket and prefix to a local directory.

        Args:
            prefix (str): The prefix path in the S3 bucket to download files from.
            local_dir (str): The local directory to save the downloaded files.

        Returns:
            bool: True if files are downloaded successfully, False otherwise.
        """
        # ...
```

src/validate.py (1)
277-284: Consider encapsulating cloud storage parameters

The `validate` function now includes several cloud storage-related parameters (`access_key`, `secret_key`, `endpoint_url`, `bucket`, `session_token`, `prefix`). To enhance maintainability and readability, consider encapsulating these parameters into a configuration object or class.
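One way to do this is a small dataclass. The sketch below is illustrative only — the name `CloudStorageConfig` and the simplified `validate` signature are assumptions, not code from this PR:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CloudStorageConfig:
    """Bundles the six cloud-storage settings currently passed individually."""
    access_key: str
    secret_key: str
    endpoint_url: str
    bucket: str
    prefix: str = ""
    session_token: Optional[str] = None


def validate(model_name_or_path: str, storage: Optional[CloudStorageConfig] = None) -> str:
    # With a config object, the cloud-storage branch needs a single None check
    # instead of testing six separate keyword arguments.
    if storage is not None:
        print(f"would download from {storage.bucket}/{storage.prefix}")
    return model_name_or_path


cfg = CloudStorageConfig("AK", "SK", "https://example.com", "models", "runs/1")
print(validate("my-model", cfg))
```

This also keeps the `validate` signature stable if further storage settings (e.g. a region) are added later.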
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- requirements.txt (1 hunks)
- src/core/cloudflare_utils.py (1 hunks)
- src/validate.py (9 hunks)
✅ Files skipped from review due to trivial changes (1)
- requirements.txt
🔇 Additional comments (4)
src/validate.py (4)
Line range hint 26-37: Import statements are appropriate
The added imports for CloudStorage and other modules are correct and necessary for the new functionality.
Line range hint 94-106: Ensure 'base_model' parameter is used appropriately
The addition of the base_model parameter to load_tokenizer is properly integrated into the function. The logic for adding special tokens based on the base model is correctly implemented.
300-329: Conditional logic for model loading is correctly implemented
The logic to handle model loading from cloud storage when hg_repo_id is None is correctly implemented. The code properly initializes CloudStorage, downloads the necessary files, and handles failures appropriately.
497-529: Exception handling is comprehensive
The exception handling in the loop includes specific handlers for OSError, RuntimeError, and ValueError, as well as a general exception handler. This approach ensures that different types of errors are appropriately managed.
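As a quick illustration of why this layering works, here is a minimal, hypothetical sketch (not the PR's code): the specific handlers must be listed before the catch-all `Exception`, or they are never reached.

```python
def classify_failure(fn):
    """Run fn() and report which handler caught the error (illustrative only)."""
    try:
        fn()
        return "ok"
    except OSError:
        return "os-error"      # e.g. missing checkpoint files on disk
    except (RuntimeError, ValueError):
        return "model-error"   # e.g. bad weights or invalid config values
    except Exception:
        return "unexpected"    # catch-all must come last


def boom(exc):
    def _raise():
        raise exc
    return _raise


print(classify_failure(boom(OSError("no file"))))     # os-error
print(classify_failure(boom(ValueError("bad cfg"))))  # model-error
print(classify_failure(boom(KeyError("oops"))))       # unexpected
```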
```python
logger.error(
    "Please provide access_key, secret_key, session_token and endpoint_url"
)
```
🛠️ Refactor suggestion
Update error message to reflect required parameters
The error message lists session_token as a required parameter, but session_token is optional for AWS credentials, and the parameter check itself does not verify it. Update the error message so it lists only the parameters that are actually required.
Apply this diff to correct the error message:

```diff
 logger.error(
-    "Please provide access_key, secret_key, session_token and endpoint_url"
+    "Please provide access_key, secret_key, and endpoint_url"
 )
```
```python
    self.access_key is None
    or self.secret_key is None
    or self.endpoint_url is None
):
    logger.error(
        "Please provide access_key, secret_key, session_token and endpoint_url"
    )
    raise
```
Ensure bucket parameter is validated before use
The bucket parameter is used in the download_files method, but there is no validation to ensure it is provided. If self.bucket is None, it will cause a runtime error when making AWS calls. Consider adding a check to validate that bucket is not None.
Apply this diff to include bucket in the parameter checks:

```diff
 if (
     self.access_key is None
     or self.secret_key is None
     or self.endpoint_url is None
+    or self.bucket is None
 ):
```

Update the error message accordingly:

```diff
 logger.error(
-    "Please provide access_key, secret_key, and endpoint_url"
+    "Please provide access_key, secret_key, endpoint_url, and bucket"
 )
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| self.access_key is None | |
| or self.secret_key is None | |
| or self.endpoint_url is None | |
| ): | |
| logger.error( | |
| "Please provide access_key, secret_key, session_token and endpoint_url" | |
| ) | |
| raise | |
| self.access_key is None | |
| or self.secret_key is None | |
| or self.endpoint_url is None | |
| or self.bucket is None | |
| ): | |
| logger.error( | |
| "Please provide access_key, secret_key, endpoint_url, and bucket" | |
| ) | |
| raise |
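A compact way to keep such checks maintainable is to collect all missing fields before raising, so the error names everything at once. A sketch only — the helper name `validate_config` is illustrative, not from the PR:

```python
def validate_config(access_key, secret_key, endpoint_url, bucket):
    """Raise ValueError naming every missing required setting."""
    required = [
        ("access_key", access_key),
        ("secret_key", secret_key),
        ("endpoint_url", endpoint_url),
        ("bucket", bucket),
    ]
    missing = [name for name, value in required if value is None]
    if missing:
        raise ValueError(
            f"Missing required CloudStorage settings: {', '.join(missing)}"
        )


# Passes silently when everything is set:
validate_config("AK", "SK", "https://example.r2.example.com", "models")

# Reports every gap in one message:
try:
    validate_config("AK", None, None, "models")
except ValueError as e:
    print(e)  # Missing required CloudStorage settings: secret_key, endpoint_url
```

Adding a new required field then means adding one tuple instead of extending both the condition and the message by hand.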
Fix incorrect use of raise without specifying an exception
The raise statement on line 31 should specify an exception to raise. Using raise without an exception outside of an exception handler will cause a RuntimeError.
Apply this diff to fix the issue:

```diff
 logger.error(
     "Please provide access_key, secret_key, session_token and endpoint_url"
 )
-raise
+raise ValueError("Missing required credentials for CloudStorage")
```
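For context, a bare `raise` is only legal while an exception is being handled; anywhere else, Python raises a `RuntimeError` of its own, which is exactly the failure mode being fixed. A minimal demonstration:

```python
def fail_without_context():
    # A bare raise with no active exception is itself an error.
    raise


try:
    fail_without_context()
except RuntimeError as e:
    print(e)  # No active exception to re-raise
```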
```python
local_file_path = os.path.join(local_dir, file_key)

os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
```
Handle potential issues with file paths when downloading files
When constructing local_file_path, if file_key includes absolute paths or parent directory references (e.g., starts with / or contains ../), it could lead to security vulnerabilities or overwrite critical files. Ensure that file_key is properly sanitized to prevent path traversal attacks.
Apply this diff to sanitize the file paths:

```diff
 file_key = obj["Key"]
+# Sanitize file_key to prevent path traversal
+safe_file_key = os.path.normpath(file_key).lstrip(os.sep)
 local_file_path = os.path.join(local_dir, safe_file_key)
```

Committable suggestion skipped: line range outside the PR's diff.
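Note that `os.path.normpath` alone does not remove leading `..` segments, so a stricter check is worth considering. A sketch under POSIX path assumptions — the helper `safe_local_path` is illustrative, not from the PR — that also verifies the joined result still resolves inside the target directory:

```python
import os


def safe_local_path(local_dir: str, file_key: str) -> str:
    """Join file_key under local_dir, rejecting path traversal."""
    # Normalize the object key and strip any leading separators.
    safe_key = os.path.normpath(file_key).lstrip(os.sep)
    candidate = os.path.join(local_dir, safe_key)
    # normpath keeps leading ".." segments, so confirm the resolved
    # path is still contained in local_dir.
    base = os.path.abspath(local_dir)
    if os.path.commonpath([base, os.path.abspath(candidate)]) != base:
        raise ValueError(f"Unsafe object key: {file_key!r}")
    return candidate


print(safe_local_path("/tmp/models", "adapter/config.json"))
# /tmp/models/adapter/config.json

try:
    safe_local_path("/tmp/models", "../../etc/passwd")
except ValueError as e:
    print(e)
```

An absolute key such as `/etc/passwd` is also defused: the leading separator is stripped, so it lands at `/tmp/models/etc/passwd`.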
```python
response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)

if "Contents" in response:
    for obj in response["Contents"]:
        file_key = obj["Key"]
        local_file_path = os.path.join(local_dir, file_key)

        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

        self.client.download_file(self.bucket, file_key, local_file_path)
        logger.info(f"Downloaded {file_key} to {local_file_path}")
return True
```
🛠️ Refactor suggestion
Handle pagination to download all files when there are more than 1000 objects
The list_objects_v2 method returns up to 1000 objects per request. If there are more than 1000 objects matching the prefix, additional requests are needed to retrieve all objects. Implement pagination to ensure all files are downloaded.
Apply this diff to use a paginator:

```diff
-response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)
-if "Contents" in response:
-    for obj in response["Contents"]:
+paginator = self.client.get_paginator('list_objects_v2')
+for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+    if 'Contents' in page:
+        for obj in page['Contents']:
             file_key = obj["Key"]
             local_file_path = os.path.join(local_dir, file_key)
             os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
             self.client.download_file(self.bucket, file_key, local_file_path)
             logger.info(f"Downloaded {file_key} to {local_file_path}")
 return True
```

Committable suggestion skipped: line range outside the PR's diff.
```python
with open(os.path.join(model_path, "adapter_config.json"), "r") as f:
    adapter_config = json.load(f)
base_model = adapter_config["base_model_name_or_path"]
model = AutoModelForCausalLM.from_pretrained(
    base_model, token=HF_TOKEN, **model_kwargs
)
# download the adapter weights
download_lora_repo(model_name_or_path, revision)

model = PeftModel.from_pretrained(
```
Verify the authentication parameter name in from_pretrained calls

The calls to AutoModelForCausalLM.from_pretrained pass `token=HF_TOKEN`. Which parameter name is correct depends on the pinned transformers version: recent releases accept `token` and deprecate `use_auth_token`, while older releases only recognize `use_auth_token`. Confirm the version in requirements.txt and use the matching parameter; passing the wrong one can silently skip authentication and cause failures when accessing private or gated models.

If the project pins an older transformers release, apply this diff:

```diff
 # First occurrence
 model = AutoModelForCausalLM.from_pretrained(
-    base_model, token=HF_TOKEN, **model_kwargs
+    base_model, use_auth_token=HF_TOKEN, **model_kwargs
 )
 # Second occurrence
 model = AutoModelForCausalLM.from_pretrained(
-    model_path, token=HF_TOKEN, **model_kwargs
+    model_path, use_auth_token=HF_TOKEN, **model_kwargs
 )
```

Also applies to: 158-160
migrate to r2 instead of hg
Summary by CodeRabbit

New Features
- Introduced a `CloudStorage` class for cloud storage interactions using Boto3.

Improvements

Bug Fixes