diff --git a/requirements.txt b/requirements.txt
index 2dca12b..03f1f2f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ python-dotenv
 peft>=0.10.0
 gitpython
 pre-commit
+boto3
diff --git a/src/core/cloudflare_utils.py b/src/core/cloudflare_utils.py
new file mode 100644
index 0000000..b2769c9
--- /dev/null
+++ b/src/core/cloudflare_utils.py
@@ -0,0 +1,56 @@
+import os
+import boto3
+from loguru import logger
+
+
+class CloudStorage:
+    def __init__(
+        self,
+        access_key=None,
+        secret_key=None,
+        endpoint_url=None,
+        bucket=None,
+        session_token=None,
+    ):
+        self.access_key = access_key
+        self.secret_key = secret_key
+        self.endpoint_url = endpoint_url
+        self.bucket = bucket
+        self.client = None
+        self.session_token = session_token
+
+    def initialize(self):
+        if (
+            self.access_key is None
+            or self.secret_key is None
+            or self.endpoint_url is None
+        ):
+            logger.error(
+                "Please provide access_key, secret_key and endpoint_url"
+            )
+            raise ValueError("access_key, secret_key and endpoint_url are required")
+        self.client = boto3.client(
+            "s3",
+            endpoint_url=self.endpoint_url,
+            aws_access_key_id=self.access_key,
+            aws_secret_access_key=self.secret_key,
+            aws_session_token=self.session_token,
+        )
+        return self
+
+    def download_files(self, prefix: str, local_dir: str) -> bool:
+        response = self.client.list_objects_v2(Prefix=prefix, Bucket=self.bucket)
+
+        if "Contents" in response:
+            for obj in response["Contents"]:
+                file_key = obj["Key"]
+                local_file_path = os.path.join(local_dir, file_key)
+
+                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+
+                self.client.download_file(self.bucket, file_key, local_file_path)
+                logger.info(f"Downloaded {file_key} to {local_file_path}")
+            return True
+        else:
+            logger.info("No files found in the bucket.")
+            return False
diff --git a/src/validate.py b/src/validate.py
index d69b80b..5e4c8d8 100644
--- a/src/validate.py
+++ b/src/validate.py
@@ -23,7 +23,7 @@
 from core.collator import SFTDataCollator
 from core.dataset import UnifiedSFTDataset
 from core.template import template_dict
-from core.hf_utils import download_lora_config, download_lora_repo
+from core.hf_utils import download_lora_repo
 from core.gpu_utils import get_gpu_type
 from core.constant import SUPPORTED_BASE_MODELS
 from core.exception import (
@@ -34,6 +34,7 @@
 from tenacity import retry, stop_after_attempt, wait_exponential
 from client.fed_ledger import FedLedger
 from peft import PeftModel
+from core.cloudflare_utils import CloudStorage
 import sys
 
 load_dotenv()
@@ -90,12 +91,12 @@ def download_file(url):
         raise e
 
 
-def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:
+def load_tokenizer(model_name_or_path: str, base_model: str) -> AutoTokenizer:
     tokenizer = AutoTokenizer.from_pretrained(
         model_name_or_path,
         use_fast=True,
     )
-    if "gemma" in model_name_or_path.lower():
+    if "gemma" in base_model.lower():
         tokenizer.add_special_tokens(
             {"additional_special_tokens": ["", ""]}
         )
@@ -113,9 +114,9 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:
 
 
 def load_model(
-    model_name_or_path: str, lora_only: bool, revision: str, val_args: TrainingArguments
+    model_path: str, lora_only: bool, val_args: TrainingArguments
 ) -> Trainer:
-    logger.info(f"Loading model from base model: {model_name_or_path}")
+    logger.info(f"Loading model from base model: {model_path}")
 
     if val_args.use_cpu:
         torch_dtype = torch.float32
@@ -128,19 +129,19 @@ def load_model(
         device_map=None,
     )
     # check whether it is a lora weight
-    if download_lora_config(model_name_or_path, revision):
+
+    if os.path.isfile(os.path.join(model_path, "adapter_config.json")):
         logger.info("Repo is a lora weight, loading model with adapter weights")
-        with open("lora/adapter_config.json", "r") as f:
+        with open(os.path.join(model_path, "adapter_config.json"), "r") as f:
             adapter_config = json.load(f)
         base_model = adapter_config["base_model_name_or_path"]
         model = AutoModelForCausalLM.from_pretrained(
             base_model, token=HF_TOKEN, **model_kwargs
         )
-        # download the adapter weights
-        download_lora_repo(model_name_or_path, revision)
+
         model = PeftModel.from_pretrained(
             model,
-            "lora",
+            model_path,
             device_map=None,
         )
         model = model.merge_and_unload()
@@ -154,7 +155,7 @@ def load_model(
         return None
     logger.info("Repo is a full fine-tuned model, loading model directly")
     model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path, token=HF_TOKEN, **model_kwargs
+        model_path, token=HF_TOKEN, **model_kwargs
     )
 
     if "output_router_logits" in model.config.to_dict():
@@ -244,7 +245,6 @@
 
 
 @click.command()
-@click.option("--model_name_or_path", required=True, type=str, help="")
 @click.option("--base_model", required=True, type=str, help="")
 @click.option("--eval_file", default="./data/dummy_data.jsonl", type=str, help="")
 @click.option("--context_length", required=True, type=int)
@@ -266,7 +266,6 @@
     help="Run the script in local test mode to avoid submitting to the server",
 )
 def validate(
-    model_name_or_path: str,
     base_model: str,
     eval_file: str,
     context_length: int,
@@ -275,7 +274,14 @@
     assignment_id: str = None,
     local_test: bool = False,
     lora_only: bool = True,
+    hg_repo_id: str = None,
     revision: str = "main",
+    access_key: str = None,
+    secret_key: str = None,
+    endpoint_url: str = None,
+    bucket: str = None,
+    session_token: str = None,
+    prefix: str = None,
 ):
     if not local_test and assignment_id is None:
         raise ValueError(
@@ -291,11 +297,36 @@
     val_args = parser.parse_json_file(json_file=validation_args_file)[0]
 
     gpu_type = get_gpu_type()
-    tokenizer = load_tokenizer(model_name_or_path)
-    eval_dataset = load_sft_dataset(
-        eval_file, context_length, template_name=base_model, tokenizer=tokenizer
-    )
-    model = load_model(model_name_or_path, lora_only, revision, val_args)
+    if hg_repo_id is None:
+        cf_storage = CloudStorage(
+            access_key=access_key,
+            secret_key=secret_key,
+            endpoint_url=endpoint_url,
+            bucket=bucket,
+            session_token=session_token,
+        )
+        cf_storage.initialize()
+        cf_download_result = cf_storage.download_files(
+            prefix=prefix, local_dir="lora"
+        )
+        if not cf_download_result:
+            fed_ledger.mark_assignment_as_failed(assignment_id)
+            return
+        lora_model_path = os.path.join("lora", prefix)
+        tokenizer = load_tokenizer(
+            model_name_or_path=lora_model_path, base_model=base_model
+        )
+        eval_dataset = load_sft_dataset(
+            eval_file, context_length, template_name=base_model, tokenizer=tokenizer
+        )
+        model = load_model(lora_model_path, lora_only, val_args)
+    else:
+        tokenizer = load_tokenizer(hg_repo_id, base_model=base_model)
+        eval_dataset = load_sft_dataset(
+            eval_file, context_length, template_name=base_model, tokenizer=tokenizer
+        )
+        download_lora_repo(hg_repo_id, revision)
+        model = load_model("lora", lora_only, val_args)
     # if model is not loaded, mark the assignment as failed and return
     if model is None:
         fed_ledger.mark_assignment_as_failed(assignment_id)
@@ -458,25 +489,44 @@
             continue
         resp = resp.json()
         eval_file = download_file(resp["data"]["validation_set_url"])
-        revision = resp["task_submission"]["data"].get("revision", "main")
         assignment_id = resp["id"]
         for attempt in range(3):
             try:
                 ctx = click.Context(validate)
-                ctx.invoke(
-                    validate,
-                    model_name_or_path=resp["task_submission"]["data"]["hg_repo_id"],
-                    base_model=resp["data"]["base_model"],
-                    eval_file=eval_file,
-                    context_length=resp["data"]["context_length"],
-                    max_params=resp["data"]["max_params"],
-                    validation_args_file=validation_args_file,
-                    assignment_id=resp["id"],
-                    local_test=False,
-                    lora_only=lora_only,
-                    revision=revision,
-                )
+                if "hg_repo_id" in resp["task_submission"]["data"]:
+                    revision = resp["task_submission"]["data"].get("revision", "main")
+                    ctx.invoke(
+                        validate,
+                        hg_repo_id=resp["task_submission"]["data"]["hg_repo_id"],
+                        base_model=resp["data"]["base_model"],
+                        eval_file=eval_file,
+                        context_length=resp["data"]["context_length"],
+                        max_params=resp["data"]["max_params"],
+                        validation_args_file=validation_args_file,
+                        assignment_id=resp["id"],
+                        local_test=False,
+                        lora_only=lora_only,
+                        revision=revision,
+                    )
+                else:
+                    ctx.invoke(
+                        validate,
+                        base_model=resp["data"]["base_model"],
+                        eval_file=eval_file,
+                        context_length=resp["data"]["context_length"],
+                        max_params=resp["data"]["max_params"],
+                        validation_args_file=validation_args_file,
+                        assignment_id=resp["id"],
+                        local_test=False,
+                        lora_only=lora_only,
+                        access_key=resp["data"]["access_key"],
+                        secret_key=resp["data"]["secret_key"],
+                        endpoint_url=resp["data"]["endpoint_url"],
+                        bucket=resp["data"]["bucket"],
+                        session_token=resp["data"]["session_token"],
+                        prefix=resp["data"]["prefix"],
+                    )
                 break  # Break the loop if no exception
             except KeyboardInterrupt:
                 # directly terminate the process if keyboard interrupt