12 changes: 10 additions & 2 deletions src/axolotl/cli/cloud/__init__.py
@@ -67,8 +67,16 @@ def do_cli_lm_eval(
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")

with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)
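
For reference, a minimal sketch of driving the new provider dispatch programmatically rather than through the CLI. The import path and parameter names are taken from this diff; the file names are illustrative, and the cloud.yaml is assumed to set `provider: baseten`.

# Hypothetical usage sketch; file names are examples only.
from axolotl.cli.cloud import do_cli_lm_eval

# cloud.yaml selects the provider ("modal" or "baseten") plus project/GPU settings;
# eval.yaml is a standard axolotl lm-eval config.
do_cli_lm_eval(cloud_config="cloud.yaml", config="eval.yaml")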
20 changes: 20 additions & 0 deletions src/axolotl/cli/cloud/baseten/__init__.py
@@ -46,3 +46,23 @@ def train(
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

def lm_eval(
self,
config_yaml: str,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
)
shutil.copyfile(
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
)
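
The new method mirrors train(): it stages cloud.yaml, eval.yaml, eval.sh, and eval_sft.py in a temporary directory and pushes the job with the Truss CLI. Below is a minimal sketch of calling it directly, assuming the import path implied by this diff; the config keys mirror those read by eval_sft.py, and the secret name is illustrative.

# Hypothetical direct usage; "HF_TOKEN" is an example secret name, not one required by the PR.
from axolotl.cli.cloud.baseten import BasetenCloud

cloud = BasetenCloud({"project_name": "axolotl-eval", "secrets": ["HF_TOKEN"], "gpu": "h100"})
with open("eval.yaml", "r", encoding="utf-8") as f:
    cloud.lm_eval(f.read())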
8 changes: 8 additions & 0 deletions src/axolotl/cli/cloud/baseten/template/eval.sh
@@ -0,0 +1,8 @@
#!/bin/bash
set -eux

export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000

axolotl lm-eval eval.yaml
81 changes: 81 additions & 0 deletions src/axolotl/cli/cloud/baseten/template/eval_sft.py
@@ -0,0 +1,81 @@
"""
Baseten Training Script for Axolotl
"""

# pylint: skip-file
import yaml
from truss.base import truss_config

# Import necessary classes from the Baseten Training SDK
from truss_train import definitions

with open("cloud.yaml", "r", encoding="utf-8") as cloud_file:
    cloud_config = yaml.safe_load(cloud_file)
gpu = cloud_config.get("gpu", "h100")  # currently informational; the accelerator below is pinned to H100
gpu_count = 1  # int(cloud_config.get("gpu_count", 1)); only single GPU supported at the moment
node_count = 1  # int(cloud_config.get("node_count", 1)); only single-node support for lm-eval
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
# launcher = cloud_config.get("launcher", "accelerate")
# launcher_args = cloud_config.get("launcher_args", [])
script_name = "eval.sh"

# launcher_args_str = ""
# if launcher_args:
# launcher_args_str = "-- " + " ".join(launcher_args)

# 1. Define a base image for your training job
# must use a torch 2.7.x image for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"

# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the Baseten workspace, like API keys, are referenced using
# `SecretReference`.

env_vars = {
# "AXOLOTL_LAUNCHER": launcher,
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)

training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)

# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)

# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)


# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)
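
For orientation, the cloud.yaml keys consumed by this template correspond to a mapping like the one below, as loaded by yaml.safe_load at the top of the file. Every value shown is an example, and gpu_count/node_count are currently forced to 1 regardless of what is configured.

# Illustrative contents of cloud.yaml for an lm-eval job; values are examples only.
cloud_config_example = {
    "gpu": "h100",                      # read, but the accelerator is currently pinned to H100
    "project_name": "axolotl-project",  # Baseten training project the job is attached to
    "secrets": ["HF_TOKEN"],            # Baseten workspace secret names exposed as env vars
}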
6 changes: 6 additions & 0 deletions src/axolotl/cli/cloud/baseten/template/train_sft.py
@@ -44,6 +44,12 @@
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)

# 3. Define the Compute Resources for the Training Job