-
Notifications
You must be signed in to change notification settings - Fork 40
Migrate 2 r2 #60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Migrate 2 r2 #60
Changes from all commits
c447b77
4fd5329
6dfc146
e22f8b7
29ecfe9
4a988be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,3 +16,4 @@ python-dotenv | |
| peft>=0.10.0 | ||
| gitpython | ||
| pre-commit | ||
| boto3 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,56 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import boto3 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from loguru import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class CloudStorage: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| access_key=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| secret_key=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| endpoint_url=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_token=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.access_key = access_key | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.secret_key = secret_key | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.endpoint_url = endpoint_url | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.bucket = bucket | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.session_token = session_token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def initialize(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.access_key is None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| or self.secret_key is None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| or self.endpoint_url is None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Please provide access_key, secret_key, session_token and endpoint_url" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+24
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure The Apply this diff to include if (
self.access_key is None
or self.secret_key is None
or self.endpoint_url is None
+ or self.bucket is None
):Update the error message accordingly: logger.error(
- "Please provide access_key, secret_key, and endpoint_url"
+ "Please provide access_key, secret_key, endpoint_url, and bucket"
)📝 Committable suggestion
Suggested change
Fix incorrect use of The Apply this diff to fix the issue: logger.error(
"Please provide access_key, secret_key, session_token and endpoint_url"
)
- raise
+ raise ValueError("Missing required credentials for CloudStorage")📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client = boto3.client( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "s3", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| endpoint_url=self.endpoint_url, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| aws_access_key_id=self.access_key, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| aws_secret_access_key=self.secret_key, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| aws_session_token=self.session_token, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def download_files(self, prefix: str, local_dir: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "Contents" in response: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for obj in response["Contents"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| file_key = obj["Key"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_file_path = os.path.join(local_dir, file_key) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os.makedirs(os.path.dirname(local_file_path), exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+47
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle potential issues with file paths when downloading files When constructing Apply this diff to sanitize the file paths: file_key = obj["Key"]
+ # Sanitize file_key to prevent path traversal
+ safe_file_key = os.path.normpath(file_key).lstrip(os.sep)
local_file_path = os.path.join(local_dir, safe_file_key)
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client.download_file(self.bucket, file_key, local_file_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"Downloaded {file_key} to {local_file_path}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+42
to
+53
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Handle pagination to download all files when there are more than 1000 objects The Apply this diff to use a paginator: - response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)
- if "Contents" in response:
- for obj in response["Contents"]:
+ paginator = self.client.get_paginator('list_objects_v2')
+ for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+ if 'Contents' in page:
+ for obj in page['Contents']:
file_key = obj["Key"]
local_file_path = os.path.join(local_dir, file_key)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
self.client.download_file(self.bucket, file_key, local_file_path)
logger.info(f"Downloaded {file_key} to {local_file_path}")
return True
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("No files found in the bucket.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |
| from core.collator import SFTDataCollator | ||
| from core.dataset import UnifiedSFTDataset | ||
| from core.template import template_dict | ||
| from core.hf_utils import download_lora_config, download_lora_repo | ||
| from core.hf_utils import download_lora_repo | ||
| from core.gpu_utils import get_gpu_type | ||
| from core.constant import SUPPORTED_BASE_MODELS | ||
| from core.exception import ( | ||
|
|
@@ -34,6 +34,7 @@ | |
| from tenacity import retry, stop_after_attempt, wait_exponential | ||
| from client.fed_ledger import FedLedger | ||
| from peft import PeftModel | ||
| from core.cloudflare_utils import CloudStorage | ||
| import sys | ||
|
|
||
| load_dotenv() | ||
|
|
@@ -90,12 +91,12 @@ def download_file(url): | |
| raise e | ||
|
|
||
|
|
||
| def load_tokenizer(model_name_or_path: str) -> AutoTokenizer: | ||
| def load_tokenizer(model_name_or_path: str, base_model: str) -> AutoTokenizer: | ||
| tokenizer = AutoTokenizer.from_pretrained( | ||
| model_name_or_path, | ||
| use_fast=True, | ||
| ) | ||
| if "gemma" in model_name_or_path.lower(): | ||
| if "gemma" in base_model.lower(): | ||
| tokenizer.add_special_tokens( | ||
| {"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]} | ||
| ) | ||
|
|
@@ -113,9 +114,9 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer: | |
|
|
||
|
|
||
| def load_model( | ||
| model_name_or_path: str, lora_only: bool, revision: str, val_args: TrainingArguments | ||
| model_path: str, lora_only: bool, val_args: TrainingArguments | ||
| ) -> Trainer: | ||
| logger.info(f"Loading model from base model: {model_name_or_path}") | ||
| logger.info(f"Loading model from base model: {model_path}") | ||
|
|
||
| if val_args.use_cpu: | ||
| torch_dtype = torch.float32 | ||
|
|
@@ -128,19 +129,19 @@ def load_model( | |
| device_map=None, | ||
| ) | ||
| # check whether it is a lora weight | ||
| if download_lora_config(model_name_or_path, revision): | ||
|
|
||
| if os.path.isfile(os.path.join(model_path, "adapter_config.json")): | ||
| logger.info("Repo is a lora weight, loading model with adapter weights") | ||
| with open("lora/adapter_config.json", "r") as f: | ||
| with open(os.path.join(model_path, "adapter_config.json"), "r") as f: | ||
| adapter_config = json.load(f) | ||
| base_model = adapter_config["base_model_name_or_path"] | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| base_model, token=HF_TOKEN, **model_kwargs | ||
| ) | ||
| # download the adapter weights | ||
| download_lora_repo(model_name_or_path, revision) | ||
|
|
||
| model = PeftModel.from_pretrained( | ||
|
Comment on lines
+135
to
142
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Incorrect parameter 'token' in In the calls to Apply this diff to fix the issue: # First occurrence
model = AutoModelForCausalLM.from_pretrained(
- base_model, token=HF_TOKEN, **model_kwargs
+ base_model, use_auth_token=HF_TOKEN, **model_kwargs
)
# Second occurrence
model = AutoModelForCausalLM.from_pretrained(
- model_path, token=HF_TOKEN, **model_kwargs
+ model_path, use_auth_token=HF_TOKEN, **model_kwargs
)Also applies to: 158-160 |
||
| model, | ||
| "lora", | ||
| model_path, | ||
| device_map=None, | ||
| ) | ||
| model = model.merge_and_unload() | ||
|
|
@@ -154,7 +155,7 @@ def load_model( | |
| return None | ||
| logger.info("Repo is a full fine-tuned model, loading model directly") | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_name_or_path, token=HF_TOKEN, **model_kwargs | ||
| model_path, token=HF_TOKEN, **model_kwargs | ||
| ) | ||
|
|
||
| if "output_router_logits" in model.config.to_dict(): | ||
|
|
@@ -244,7 +245,6 @@ def cli(): | |
|
|
||
|
|
||
| @click.command() | ||
| @click.option("--model_name_or_path", required=True, type=str, help="") | ||
| @click.option("--base_model", required=True, type=str, help="") | ||
| @click.option("--eval_file", default="./data/dummy_data.jsonl", type=str, help="") | ||
| @click.option("--context_length", required=True, type=int) | ||
|
|
@@ -266,7 +266,6 @@ def cli(): | |
| help="Run the script in local test mode to avoid submitting to the server", | ||
| ) | ||
| def validate( | ||
| model_name_or_path: str, | ||
| base_model: str, | ||
| eval_file: str, | ||
| context_length: int, | ||
|
|
@@ -275,7 +274,14 @@ def validate( | |
| assignment_id: str = None, | ||
| local_test: bool = False, | ||
| lora_only: bool = True, | ||
| hg_repo_id: str = None, | ||
| revision: str = "main", | ||
| access_key: str = None, | ||
| secret_key: str = None, | ||
| endpoint_url: str = None, | ||
| bucket: str = None, | ||
| session_token: str = None, | ||
| prefix: str = None, | ||
| ): | ||
| if not local_test and assignment_id is None: | ||
| raise ValueError( | ||
|
|
@@ -291,11 +297,36 @@ def validate( | |
| val_args = parser.parse_json_file(json_file=validation_args_file)[0] | ||
| gpu_type = get_gpu_type() | ||
|
|
||
| tokenizer = load_tokenizer(model_name_or_path) | ||
| eval_dataset = load_sft_dataset( | ||
| eval_file, context_length, template_name=base_model, tokenizer=tokenizer | ||
| ) | ||
| model = load_model(model_name_or_path, lora_only, revision, val_args) | ||
| if hg_repo_id is None: | ||
| cf_storage = CloudStorage( | ||
| access_key=access_key, | ||
| secret_key=secret_key, | ||
| endpoint_url=endpoint_url, | ||
| bucket=bucket, | ||
| session_token=session_token, | ||
| ) | ||
| cf_storage.initialize() | ||
| cf_download_result = cf_storage.download_files( | ||
| prefix=prefix, local_dir="lora" | ||
| ) | ||
| if not cf_download_result: | ||
| fed_ledger.mark_assignment_as_failed(assignment_id) | ||
| return | ||
| lora_model_path = os.path.join("lora", prefix) | ||
| tokenizer = load_tokenizer( | ||
| model_name_or_path=lora_model_path, base_model=base_model | ||
| ) | ||
| eval_dataset = load_sft_dataset( | ||
| eval_file, context_length, template_name=base_model, tokenizer=tokenizer | ||
| ) | ||
| model = load_model(lora_model_path, lora_only, val_args) | ||
| else: | ||
| tokenizer = load_tokenizer(hg_repo_id, base_model=base_model) | ||
| eval_dataset = load_sft_dataset( | ||
| eval_file, context_length, template_name=base_model, tokenizer=tokenizer | ||
| ) | ||
| download_lora_repo(hg_repo_id, revision) | ||
| model = load_model("lora", lora_only, val_args) | ||
| # if model is not loaded, mark the assignment as failed and return | ||
| if model is None: | ||
| fed_ledger.mark_assignment_as_failed(assignment_id) | ||
|
|
@@ -458,25 +489,44 @@ def loop( | |
| continue | ||
| resp = resp.json() | ||
| eval_file = download_file(resp["data"]["validation_set_url"]) | ||
| revision = resp["task_submission"]["data"].get("revision", "main") | ||
| assignment_id = resp["id"] | ||
|
|
||
| for attempt in range(3): | ||
| try: | ||
| ctx = click.Context(validate) | ||
| ctx.invoke( | ||
| validate, | ||
| model_name_or_path=resp["task_submission"]["data"]["hg_repo_id"], | ||
| base_model=resp["data"]["base_model"], | ||
| eval_file=eval_file, | ||
| context_length=resp["data"]["context_length"], | ||
| max_params=resp["data"]["max_params"], | ||
| validation_args_file=validation_args_file, | ||
| assignment_id=resp["id"], | ||
| local_test=False, | ||
| lora_only=lora_only, | ||
| revision=revision, | ||
| ) | ||
| if "hg_repo_id" in resp["task_submission"]["data"]: | ||
| revision = resp["task_submission"]["data"].get("revision", "main") | ||
| ctx.invoke( | ||
| validate, | ||
| hg_repo_id=resp["task_submission"]["data"]["hg_repo_id"], | ||
| base_model=resp["data"]["base_model"], | ||
| eval_file=eval_file, | ||
| context_length=resp["data"]["context_length"], | ||
| max_params=resp["data"]["max_params"], | ||
| validation_args_file=validation_args_file, | ||
| assignment_id=resp["id"], | ||
| local_test=False, | ||
| lora_only=lora_only, | ||
| revision=revision, | ||
| ) | ||
| else: | ||
| ctx.invoke( | ||
| validate, | ||
| base_model=resp["data"]["base_model"], | ||
| eval_file=eval_file, | ||
| context_length=resp["data"]["context_length"], | ||
| max_params=resp["data"]["max_params"], | ||
| validation_args_file=validation_args_file, | ||
| assignment_id=resp["id"], | ||
| local_test=False, | ||
| lora_only=lora_only, | ||
| access_key=resp["data"]["access_key"], | ||
| secret_key=resp["data"]["secret_key"], | ||
| endpoint_url=resp["data"]["endpoint_url"], | ||
| bucket=resp["data"]["bucket"], | ||
| session_token=resp["data"]["session_token"], | ||
| prefix=resp["data"]["prefix"], | ||
| ) | ||
| break # Break the loop if no exception | ||
| except KeyboardInterrupt: | ||
| # directly terminate the process if keyboard interrupt | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Update error message to reflect required parameters
The error message mentions
session_tokenas a required parameter, butsession_tokenis optional for AWS credentials. Also, the current parameter check does not verify ifsession_tokenis provided. Update the error message to accurately reflect the required parameters.Apply this diff to correct the error message:
📝 Committable suggestion