diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py
index 60f6a51ce4..262e6a8adb 100644
--- a/src/axolotl/cli/cloud/__init__.py
+++ b/src/axolotl/cli/cloud/__init__.py
@@ -67,8 +67,16 @@ def do_cli_lm_eval(
     cloud_config: Path | str,
     config: Path | str,
 ) -> None:
-    cloud_cfg = load_cloud_cfg(cloud_config)
-    cloud = ModalCloud(cloud_cfg)
+    cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
+    provider = cloud_cfg.provider or "modal"
+    cloud: Cloud | None
+    if provider == "modal":
+        cloud = ModalCloud(cloud_cfg)
+    elif provider == "baseten":
+        cloud = BasetenCloud(cloud_cfg.to_dict())
+    else:
+        raise ValueError(f"Unsupported cloud provider: {provider}")
+
     with open(config, "r", encoding="utf-8") as file:
         config_yaml = file.read()
     cloud.lm_eval(config_yaml)
diff --git a/src/axolotl/cli/cloud/baseten/__init__.py b/src/axolotl/cli/cloud/baseten/__init__.py
index 914504de3a..17195e07be 100644
--- a/src/axolotl/cli/cloud/baseten/__init__.py
+++ b/src/axolotl/cli/cloud/baseten/__init__.py
@@ -46,3 +46,23 @@ def train(
         subprocess.run(  # nosec B603 B607
             ["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
         )
+
+    def lm_eval(
+        self,
+        config_yaml: str,
+    ):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config = self.config.copy()
+            with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
+                yaml.dump(config, cloud_fout)
+            with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
+                config_fout.write(config_yaml)
+            shutil.copyfile(
+                dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
+            )
+            shutil.copyfile(
+                dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
+            )
+            subprocess.run(  # nosec B603 B607
+                ["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
+            )
diff --git a/src/axolotl/cli/cloud/baseten/template/eval.sh b/src/axolotl/cli/cloud/baseten/template/eval.sh
new file mode 100644
index 0000000000..e89afcc57e
--- /dev/null
+++ b/src/axolotl/cli/cloud/baseten/template/eval.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -eux
+
+export NCCL_SOCKET_IFNAME="^docker0,lo"
+export NCCL_IB_DISABLE=0
+export NCCL_TIMEOUT=1800000
+
+axolotl lm-eval eval.yaml
diff --git a/src/axolotl/cli/cloud/baseten/template/eval_sft.py b/src/axolotl/cli/cloud/baseten/template/eval_sft.py
new file mode 100644
index 0000000000..a2d37b4663
--- /dev/null
+++ b/src/axolotl/cli/cloud/baseten/template/eval_sft.py
@@ -0,0 +1,81 @@
+"""
+Baseten LM-Eval Script for Axolotl
+"""
+
+# pylint: skip-file
+import yaml
+from truss.base import truss_config
+
+# Import necessary classes from the Baseten Training SDK
+from truss_train import definitions
+
+cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
+gpu = cloud_config.get("gpu", "h100")
+gpu_count = (
+    1  # int(cloud_config.get("gpu_count", 1))  # only single GPU supported at the moment
+)
+node_count = (
+    1  # int(cloud_config.get("node_count", 1))  # only single node support for lmeval
+)
+project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
+secrets = cloud_config.get("secrets", [])
+# launcher = cloud_config.get("launcher", "accelerate")
+# launcher_args = cloud_config.get("launcher_args", [])
+script_name = "eval.sh"
+
+# launcher_args_str = ""
+# if launcher_args:
+#     launcher_args_str = "-- " + " ".join(launcher_args)
+
+# 1. Define a base image for your training job
+# must use torch 2.7.x for vllm
+BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
+
+# 2. Define the Runtime Environment for the Training Job
+# This includes start commands and environment variables.
+# Secrets from the baseten workspace like API keys are referenced using
+# `SecretReference`.
+
+env_vars = {
+    # "AXOLOTL_LAUNCHER": launcher,
+    # "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
+}
+for secret_name in secrets:
+    env_vars[secret_name] = definitions.SecretReference(name=secret_name)
+
+training_runtime = definitions.Runtime(
+    start_commands=[  # Example: list of commands to run your training script
+        f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
+    ],
+    environment_variables=env_vars,
+    cache_config=definitions.CacheConfig(
+        enabled=True,
+    ),
+    checkpointing_config=definitions.CheckpointingConfig(
+        enabled=True,
+    ),
+)
+
+# 3. Define the Compute Resources for the Training Job
+training_compute = definitions.Compute(
+    node_count=node_count,
+    accelerator=truss_config.AcceleratorSpec(
+        accelerator=truss_config.Accelerator.H100,
+        count=gpu_count,
+    ),
+)
+
+# 4. Define the Training Job
+# This brings together the image, compute, and runtime configurations.
+my_training_job = definitions.TrainingJob(
+    image=definitions.Image(base_image=BASE_IMAGE),
+    compute=training_compute,
+    runtime=training_runtime,
+)
+
+
+# This config will be pushed using the Truss CLI.
+# The association of the job to the project happens at the time of push.
+first_project_with_job = definitions.TrainingProject(
+    name=project_name, job=my_training_job
+)
diff --git a/src/axolotl/cli/cloud/baseten/template/train_sft.py b/src/axolotl/cli/cloud/baseten/template/train_sft.py
index 137fb91714..6cf491b8f8 100644
--- a/src/axolotl/cli/cloud/baseten/template/train_sft.py
+++ b/src/axolotl/cli/cloud/baseten/template/train_sft.py
@@ -44,6 +44,12 @@
         f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
     ],
     environment_variables=env_vars,
+    cache_config=definitions.CacheConfig(
+        enabled=True,
+    ),
+    checkpointing_config=definitions.CheckpointingConfig(
+        enabled=True,
+    ),
 )
 
 # 3. Define the Compute Resources for the Training Job