Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ python-dotenv
peft>=0.10.0
gitpython
pre-commit
boto3
56 changes: 56 additions & 0 deletions src/core/cloudflare_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import boto3
from loguru import logger


class CloudStorage:
    # Thin wrapper around an S3-compatible object store (presumably Cloudflare
    # R2, given the module name — confirm) accessed through boto3.
    # Call initialize() before using the client.
    def __init__(
        self,
        access_key=None,
        secret_key=None,
        endpoint_url=None,
        bucket=None,
        session_token=None,
    ):
        """Store connection settings; no network calls happen here.

        Args:
            access_key: S3 access key id (checked for None by initialize()).
            secret_key: S3 secret access key (checked for None by initialize()).
            endpoint_url: Endpoint URL of the S3-compatible service (checked
                for None by initialize()).
            bucket: Bucket name used by download_files(); NOTE(review): not
                validated anywhere — a None bucket fails only at request time.
            session_token: Optional session token for temporary credentials;
                may be None for long-lived keys.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url
        self.bucket = bucket
        # boto3 client is created lazily in initialize()
        self.client = None
        self.session_token = session_token

def initialize(self):
if (
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
Comment on lines +28 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update error message to reflect required parameters

The error message mentions session_token as a required parameter, but session_token is optional for AWS credentials. Also, the current parameter check does not verify if session_token is provided. Update the error message to accurately reflect the required parameters.

Apply this diff to correct the error message:

             logger.error(
-                "Please provide access_key, secret_key, session_token and endpoint_url"
+                "Please provide access_key, secret_key, and endpoint_url"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
logger.error(
"Please provide access_key, secret_key, and endpoint_url"
)

raise
Comment on lines +24 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure bucket parameter is validated before use

The bucket parameter is used in the download_files method, but there is no validation to ensure it is provided. If self.bucket is None, it will cause a runtime error when making AWS calls. Consider adding a check to validate that bucket is not None.

Apply this diff to include bucket in the parameter checks:

             if (
                 self.access_key is None
                 or self.secret_key is None
                 or self.endpoint_url is None
+                or self.bucket is None
             ):

Update the error message accordingly:

             logger.error(
-                "Please provide access_key, secret_key, and endpoint_url"
+                "Please provide access_key, secret_key, endpoint_url, and bucket"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
or self.bucket is None
):
logger.error(
"Please provide access_key, secret_key, endpoint_url, and bucket"
)
raise

⚠️ Potential issue

Fix incorrect use of raise without specifying an exception

The raise statement on line 31 should specify an exception to raise. Using raise without an exception outside of an exception handler will cause a RuntimeError.

Apply this diff to fix the issue:

             logger.error(
                 "Please provide access_key, secret_key, session_token and endpoint_url"
             )
-            raise
+            raise ValueError("Missing required credentials for CloudStorage")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
):
logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
raise ValueError("Missing required credentials for CloudStorage")

self.client = boto3.client(
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key,
aws_session_token=self.session_token,
)
return self

def download_files(self, prefix: str, local_dir: str) -> bool:
response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)

if "Contents" in response:
for obj in response["Contents"]:
file_key = obj["Key"]
local_file_path = os.path.join(local_dir, file_key)

os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
Comment on lines +47 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle potential issues with file paths when downloading files

When constructing local_file_path, if file_key includes absolute paths or parent directory references (e.g., starts with / or contains ../), it could lead to security vulnerabilities or overwrite critical files. Ensure that file_key is properly sanitized to prevent path traversal attacks.

Apply this diff to sanitize the file paths. Note that `os.path.normpath(...).lstrip(os.sep)` alone is not sufficient: a key such as `../../etc/passwd` normalizes to a path that still begins with `..` and escapes `local_dir`, so keys that resolve outside the target directory must also be rejected:

                 file_key = obj["Key"]
+                # Sanitize file_key to prevent path traversal
+                safe_file_key = os.path.normpath(file_key).lstrip(os.sep)
+                if safe_file_key.startswith(".."):
+                    logger.warning(f"Skipping unsafe key: {file_key}")
+                    continue
                 local_file_path = os.path.join(local_dir, safe_file_key)

Committable suggestion skipped: line range outside the PR's diff.


self.client.download_file(self.bucket, file_key, local_file_path)
logger.info(f"Downloaded {file_key} to {local_file_path}")
return True
Comment on lines +42 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Handle pagination to download all files when there are more than 1000 objects

The list_objects_v2 method returns up to 1000 objects per request. If there are more than 1000 objects matching the prefix, additional requests are needed to retrieve all objects. Implement pagination to ensure all files are downloaded.

Apply this diff to use a paginator:

-        response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)

-        if "Contents" in response:
-            for obj in response["Contents"]:
+        paginator = self.client.get_paginator('list_objects_v2')
+        for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+            if 'Contents' in page:
+                for obj in page['Contents']:
                     file_key = obj["Key"]
                     local_file_path = os.path.join(local_dir, file_key)

                     os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

                     self.client.download_file(self.bucket, file_key, local_file_path)
                     logger.info(f"Downloaded {file_key} to {local_file_path}")
             return True

Committable suggestion skipped: line range outside the PR's diff.

else:
logger.info("No files found in the bucket.")
return False
114 changes: 82 additions & 32 deletions src/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from core.collator import SFTDataCollator
from core.dataset import UnifiedSFTDataset
from core.template import template_dict
from core.hf_utils import download_lora_config, download_lora_repo
from core.hf_utils import download_lora_repo
from core.gpu_utils import get_gpu_type
from core.constant import SUPPORTED_BASE_MODELS
from core.exception import (
Expand All @@ -34,6 +34,7 @@
from tenacity import retry, stop_after_attempt, wait_exponential
from client.fed_ledger import FedLedger
from peft import PeftModel
from core.cloudflare_utils import CloudStorage
import sys

load_dotenv()
Expand Down Expand Up @@ -90,12 +91,12 @@ def download_file(url):
raise e


def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:
def load_tokenizer(model_name_or_path: str, base_model: str) -> AutoTokenizer:
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
use_fast=True,
)
if "gemma" in model_name_or_path.lower():
if "gemma" in base_model.lower():
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}
)
Expand All @@ -113,9 +114,9 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:


def load_model(
model_name_or_path: str, lora_only: bool, revision: str, val_args: TrainingArguments
model_path: str, lora_only: bool, val_args: TrainingArguments
) -> Trainer:
logger.info(f"Loading model from base model: {model_name_or_path}")
logger.info(f"Loading model from base model: {model_path}")

if val_args.use_cpu:
torch_dtype = torch.float32
Expand All @@ -128,19 +129,19 @@ def load_model(
device_map=None,
)
# check whether it is a lora weight
if download_lora_config(model_name_or_path, revision):

if os.path.isfile(os.path.join(model_path, "adapter_config.json")):
logger.info("Repo is a lora weight, loading model with adapter weights")
with open("lora/adapter_config.json", "r") as f:
with open(os.path.join(model_path, "adapter_config.json"), "r") as f:
adapter_config = json.load(f)
base_model = adapter_config["base_model_name_or_path"]
model = AutoModelForCausalLM.from_pretrained(
base_model, token=HF_TOKEN, **model_kwargs
)
# download the adapter weights
download_lora_repo(model_name_or_path, revision)

model = PeftModel.from_pretrained(
Comment on lines +135 to 142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Note on the authentication parameter in from_pretrained calls

The calls to AutoModelForCausalLM.from_pretrained pass `token=HF_TOKEN`. In recent versions of `transformers` (v4.32+), `token` is the supported parameter name and the older `use_auth_token` is deprecated, so this usage is correct as long as the project's pinned `transformers` version is recent enough. If an older `transformers` release must be supported, pass `use_auth_token=HF_TOKEN` instead — verify against the version pinned in requirements.txt.

Also applies to: 158-160

model,
"lora",
model_path,
device_map=None,
)
model = model.merge_and_unload()
Expand All @@ -154,7 +155,7 @@ def load_model(
return None
logger.info("Repo is a full fine-tuned model, loading model directly")
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, token=HF_TOKEN, **model_kwargs
model_path, token=HF_TOKEN, **model_kwargs
)

if "output_router_logits" in model.config.to_dict():
Expand Down Expand Up @@ -244,7 +245,6 @@ def cli():


@click.command()
@click.option("--model_name_or_path", required=True, type=str, help="")
@click.option("--base_model", required=True, type=str, help="")
@click.option("--eval_file", default="./data/dummy_data.jsonl", type=str, help="")
@click.option("--context_length", required=True, type=int)
Expand All @@ -266,7 +266,6 @@ def cli():
help="Run the script in local test mode to avoid submitting to the server",
)
def validate(
model_name_or_path: str,
base_model: str,
eval_file: str,
context_length: int,
Expand All @@ -275,7 +274,14 @@ def validate(
assignment_id: str = None,
local_test: bool = False,
lora_only: bool = True,
hg_repo_id: str = None,
revision: str = "main",
access_key: str = None,
secret_key: str = None,
endpoint_url: str = None,
bucket: str = None,
session_token: str = None,
prefix: str = None,
):
if not local_test and assignment_id is None:
raise ValueError(
Expand All @@ -291,11 +297,36 @@ def validate(
val_args = parser.parse_json_file(json_file=validation_args_file)[0]
gpu_type = get_gpu_type()

tokenizer = load_tokenizer(model_name_or_path)
eval_dataset = load_sft_dataset(
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
)
model = load_model(model_name_or_path, lora_only, revision, val_args)
if hg_repo_id is None:
cf_storage = CloudStorage(
access_key=access_key,
secret_key=secret_key,
endpoint_url=endpoint_url,
bucket=bucket,
session_token=session_token,
)
cf_storage.initialize()
cf_download_result = cf_storage.download_files(
prefix=prefix, local_dir="lora"
)
if not cf_download_result:
fed_ledger.mark_assignment_as_failed(assignment_id)
return
lora_model_path = os.path.join("lora", prefix)
tokenizer = load_tokenizer(
model_name_or_path=lora_model_path, base_model=base_model
)
eval_dataset = load_sft_dataset(
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
)
model = load_model(lora_model_path, lora_only, val_args)
else:
tokenizer = load_tokenizer(hg_repo_id, base_model=base_model)
eval_dataset = load_sft_dataset(
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
)
download_lora_repo(hg_repo_id, revision)
model = load_model("lora", lora_only, val_args)
# if model is not loaded, mark the assignment as failed and return
if model is None:
fed_ledger.mark_assignment_as_failed(assignment_id)
Expand Down Expand Up @@ -458,25 +489,44 @@ def loop(
continue
resp = resp.json()
eval_file = download_file(resp["data"]["validation_set_url"])
revision = resp["task_submission"]["data"].get("revision", "main")
assignment_id = resp["id"]

for attempt in range(3):
try:
ctx = click.Context(validate)
ctx.invoke(
validate,
model_name_or_path=resp["task_submission"]["data"]["hg_repo_id"],
base_model=resp["data"]["base_model"],
eval_file=eval_file,
context_length=resp["data"]["context_length"],
max_params=resp["data"]["max_params"],
validation_args_file=validation_args_file,
assignment_id=resp["id"],
local_test=False,
lora_only=lora_only,
revision=revision,
)
if "hg_repo_id" in resp["task_submission"]["data"]:
revision = resp["task_submission"]["data"].get("revision", "main")
ctx.invoke(
validate,
hg_repo_id=resp["task_submission"]["data"]["hg_repo_id"],
base_model=resp["data"]["base_model"],
eval_file=eval_file,
context_length=resp["data"]["context_length"],
max_params=resp["data"]["max_params"],
validation_args_file=validation_args_file,
assignment_id=resp["id"],
local_test=False,
lora_only=lora_only,
revision=revision,
)
else:
ctx.invoke(
validate,
base_model=resp["data"]["base_model"],
eval_file=eval_file,
context_length=resp["data"]["context_length"],
max_params=resp["data"]["max_params"],
validation_args_file=validation_args_file,
assignment_id=resp["id"],
local_test=False,
lora_only=lora_only,
access_key=resp["data"]["access_key"],
secret_key=resp["data"]["secret_key"],
endpoint_url=resp["data"]["endpoint_url"],
bucket=resp["data"]["bucket"],
session_token=resp["data"]["session_token"],
prefix=resp["data"]["prefix"],
)
break # Break the loop if no exception
except KeyboardInterrupt:
# directly terminate the process if keyboard interrupt
Expand Down