Skip to content

Migrate 2 r2#60

Open
fengzz-coding wants to merge 6 commits intomainfrom
migrate-2-r2
Open

Migrate 2 r2#60
fengzz-coding wants to merge 6 commits intomainfrom
migrate-2-r2

Conversation

@fengzz-coding
Copy link
Contributor

@fengzz-coding fengzz-coding commented Oct 30, 2024

migrate to r2 instead of hg

Summary by CodeRabbit

  • New Features

    • Introduced a new CloudStorage class for cloud storage interactions using Boto3.
    • Added functionality to download files from cloud storage based on specified prefixes.
  • Improvements

    • Enhanced model validation process to support cloud storage parameters.
    • Updated method signatures for improved flexibility in model loading and validation.
  • Bug Fixes

    • Improved error handling during validation to manage specific exceptions effectively.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 30, 2024

Walkthrough

The changes in this pull request introduce a new class CloudStorage in src/core/cloudflare_utils.py for managing cloud storage interactions using the Boto3 library. Additionally, the requirements.txt file is updated to include the boto3 package. The src/validate.py file undergoes several modifications, including updates to method signatures and the addition of parameters related to cloud storage, enhancing the model loading and validation processes.

Changes

File Change Summary
requirements.txt Added package: boto3.
src/core/cloudflare_utils.py Added class: CloudStorage. Methods added: __init__, initialize, and download_files.
src/validate.py Updated method signatures for load_tokenizer and load_model. Modified validate function to include new cloud storage parameters. Enhanced error handling and logic in loop function for model validation processes.

Poem

In the cloud where data flows,
A rabbit hops where Boto goes.
With files to fetch and paths anew,
Our storage dreams are coming true!
So let us validate with glee,
A world of models, wild and free! 🐇✨


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (2)
src/core/cloudflare_utils.py (1)

6-56: Add docstrings to improve code documentation

The class CloudStorage and its methods lack docstrings. Adding docstrings will enhance code readability and help others understand the purpose and usage of the class and its methods.

Consider adding docstrings like this:

class CloudStorage:
    """
    A class to manage cloud storage interactions using Boto3.
    """

    def __init__(self, access_key=None, secret_key=None, endpoint_url=None, bucket=None, session_token=None):
        """
        Initialize the CloudStorage instance with AWS credentials and configurations.

        Args:
            access_key (str): AWS access key ID.
            secret_key (str): AWS secret access key.
            endpoint_url (str): The endpoint URL for the AWS service.
            bucket (str): The name of the S3 bucket.
            session_token (str, optional): AWS session token for temporary credentials.
        """
        # ...

    def initialize(self):
        """
        Validate the credentials and initialize the Boto3 client.

        Returns:
            CloudStorage: The initialized CloudStorage instance.

        Raises:
            ValueError: If required credentials are missing.
        """
        # ...

    def download_files(self, prefix: str, local_dir: str) -> bool:
        """
        Download files from the specified S3 bucket and prefix to a local directory.

        Args:
            prefix (str): The prefix path in the S3 bucket to download files from.
            local_dir (str): The local directory to save the downloaded files.

        Returns:
            bool: True if files are downloaded successfully, False otherwise.
        """
        # ...
src/validate.py (1)

277-284: Consider encapsulating cloud storage parameters

The validate function now includes several cloud storage-related parameters (access_key, secret_key, endpoint_url, bucket, session_token, prefix). To enhance maintainability and readability, consider encapsulating these parameters into a configuration object or class.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 8366410 and 4a988be.

📒 Files selected for processing (3)
  • requirements.txt (1 hunks)
  • src/core/cloudflare_utils.py (1 hunks)
  • src/validate.py (9 hunks)
✅ Files skipped from review due to trivial changes (1)
  • requirements.txt
🔇 Additional comments (4)
src/validate.py (4)

Line range hint 26-37: Import statements are appropriate

The added imports for CloudStorage and other modules are correct and necessary for the new functionality.


Line range hint 94-106: Ensure 'base_model' parameter is used appropriately

The addition of the base_model parameter to load_tokenizer is properly integrated into the function. The logic for adding special tokens based on the base model is correctly implemented.


300-329: Conditional logic for model loading is correctly implemented

The logic to handle model loading from cloud storage when hg_repo_id is None is correctly implemented. The code properly initializes CloudStorage, downloads the necessary files, and handles failures appropriately.


497-529: Exception handling is comprehensive

The exception handling in the loop includes specific handlers for OSError, RuntimeError, and ValueError, as well as a general exception handler. This approach ensures that different types of errors are appropriately managed.

Comment on lines +28 to +30
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update error message to reflect required parameters

The error message mentions session_token as a required parameter, but session_token is optional for AWS credentials. Also, the current parameter check does not verify if session_token is provided. Update the error message to accurately reflect the required parameters.

Apply this diff to correct the error message:

             logger.error(
-                "Please provide access_key, secret_key, session_token and endpoint_url"
+                "Please provide access_key, secret_key, and endpoint_url"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
logger.error(
"Please provide access_key, secret_key, and endpoint_url"
)

Comment on lines +24 to +31
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure bucket parameter is validated before use

The bucket parameter is used in the download_files method, but there is no validation to ensure it is provided. If self.bucket is None, it will cause a runtime error when making AWS calls. Consider adding a check to validate that bucket is not None.

Apply this diff to include bucket in the parameter checks:

             if (
                 self.access_key is None
                 or self.secret_key is None
                 or self.endpoint_url is None
+                or self.bucket is None
             ):

Update the error message accordingly:

             logger.error(
-                "Please provide access_key, secret_key, and endpoint_url"
+                "Please provide access_key, secret_key, endpoint_url, and bucket"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
or self.bucket is None
):
logger.error(
"Please provide access_key, secret_key, endpoint_url, and bucket"
)
raise

⚠️ Potential issue

Fix incorrect use of raise without specifying an exception

The raise statement on line 31 should specify an exception to raise. Using raise without an exception outside of an exception handler will cause a RuntimeError.

Apply this diff to fix the issue:

             logger.error(
                 "Please provide access_key, secret_key, session_token and endpoint_url"
             )
-            raise
+            raise ValueError("Missing required credentials for CloudStorage")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise ValueError("Missing required credentials for CloudStorage")

Comment on lines +47 to +49
local_file_path = os.path.join(local_dir, file_key)

os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle potential issues with file paths when downloading files

When constructing local_file_path, if file_key includes absolute paths or parent directory references (e.g., starts with / or contains ../), it could lead to security vulnerabilities or overwrite critical files. Ensure that file_key is properly sanitized to prevent path traversal attacks.

Apply this diff to sanitize the file paths:

                 file_key = obj["Key"]
+                # Sanitize file_key to prevent path traversal
+                safe_file_key = os.path.normpath(file_key).lstrip(os.sep)
                 local_file_path = os.path.join(local_dir, safe_file_key)

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +42 to +53
response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)

if "Contents" in response:
for obj in response["Contents"]:
file_key = obj["Key"]
local_file_path = os.path.join(local_dir, file_key)

os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

self.client.download_file(self.bucket, file_key, local_file_path)
logger.info(f"Downloaded {file_key} to {local_file_path}")
return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Handle pagination to download all files when there are more than 1000 objects

The list_objects_v2 method returns up to 1000 objects per request. If there are more than 1000 objects matching the prefix, additional requests are needed to retrieve all objects. Implement pagination to ensure all files are downloaded.

Apply this diff to use a paginator:

-        response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)

-        if "Contents" in response:
-            for obj in response["Contents"]:
+        paginator = self.client.get_paginator('list_objects_v2')
+        for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+            if 'Contents' in page:
+                for obj in page['Contents']:
                     file_key = obj["Key"]
                     local_file_path = os.path.join(local_dir, file_key)

                     os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

                     self.client.download_file(self.bucket, file_key, local_file_path)
                     logger.info(f"Downloaded {file_key} to {local_file_path}")
             return True

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +135 to 142
with open(os.path.join(model_path, "adapter_config.json"), "r") as f:
adapter_config = json.load(f)
base_model = adapter_config["base_model_name_or_path"]
model = AutoModelForCausalLM.from_pretrained(
base_model, token=HF_TOKEN, **model_kwargs
)
# download the adapter weights
download_lora_repo(model_name_or_path, revision)

model = PeftModel.from_pretrained(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Incorrect parameter 'token' in from_pretrained calls

In the calls to AutoModelForCausalLM.from_pretrained, the argument token=HF_TOKEN is used. The correct parameter name is use_auth_token. Using token will not pass the authentication token, which might cause authentication failures when accessing private models or models requiring authentication.

Apply this diff to fix the issue:

# First occurrence
            model = AutoModelForCausalLM.from_pretrained(
-                base_model, token=HF_TOKEN, **model_kwargs
+                base_model, use_auth_token=HF_TOKEN, **model_kwargs
            )

# Second occurrence
            model = AutoModelForCausalLM.from_pretrained(
-                model_path, token=HF_TOKEN, **model_kwargs
+                model_path, use_auth_token=HF_TOKEN, **model_kwargs
            )

Also applies to: 158-160

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant