108 changes: 59 additions & 49 deletions examples/kfto-sft-llm/sft.ipynb
@@ -42,9 +42,9 @@
"# Model\n",
"model_name_or_path: Meta-Llama/Meta-Llama-3.1-8B-Instruct\n",
"model_revision: main\n",
"torch_dtype: bfloat16\n",
"dtype: bfloat16\n",
"attn_implementation: flash_attention_2 # one of eager (default), sdpa or flash_attention_2\n",
"use_liger: false # use Liger kernels\n",
"use_liger_kernel: false # use Liger kernels\n",
"\n",
"# PEFT / LoRA\n",
"use_peft: true\n",
@@ -69,9 +69,9 @@
" append_concat_token: false # add additional separator token\n",
"\n",
"# SFT\n",
"max_seq_length: 1024 # max sequence length for model and packing of the dataset\n",
"dataset_batch_size: 1000 # samples to tokenize per batch\n",
"max_length: 1024 # max sequence length for model and packing of the dataset\n",
"packing: false\n",
"padding_free: false\n",
"\n",
"# Training\n",
"num_train_epochs: 10 # number of training epochs\n",
@@ -175,7 +175,7 @@
" revision=model_args.model_revision,\n",
" trust_remote_code=model_args.trust_remote_code,\n",
" attn_implementation=model_args.attn_implementation,\n",
" torch_dtype=model_args.torch_dtype,\n",
" dtype=model_args.dtype,\n",
" use_cache=False if training_args.gradient_checkpointing or\n",
" training_args.fsdp_config.get(\"activation_checkpointing\",\n",
" False) else True,\n",
@@ -294,8 +294,8 @@
"metadata": {},
"outputs": [],
"source": [
"# IMPORTANT: Labels and annotations support in create_job() method requires kubeflow-training v1.9.2+. Skip this cell if using RHOAI 2.21 or later.\n",
"%pip install -U kubeflow-training "
"%pip install --ignore-installed git+https://github.com/astefanutti/sdk.git@options\n",
"%pip install --ignore-installed git+https://github.com/kubeflow/trainer.git@master#subdirectory=api/python_api"
]
},
{
@@ -305,9 +305,9 @@
"metadata": {},
"outputs": [],
"source": [
"from kubeflow.trainer import TrainerClient, CustomTrainer, KubernetesBackendConfig\n",
"from kubeflow.trainer.options import PodTemplateOverrides, PodTemplateOverride, PodTemplateSpecOverride, ContainerOverride, Labels\n",
"from kubernetes import client\n",
"from kubeflow.training import TrainingClient\n",
"from kubeflow.training.models import V1Volume, V1VolumeMount, V1PersistentVolumeClaimVolumeSource\n",
"\n",
"api_server = \"https://kubernetes.default.svc\"\n",
"token = \"<TOKEN>\"\n",
@@ -318,7 +318,7 @@
"# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA\n",
"#configuration.verify_ssl = False\n",
"api_client = client.ApiClient(configuration)\n",
"client = TrainingClient(client_configuration=api_client.configuration)"
"client = TrainerClient(backend_config=KubernetesBackendConfig(client_configuration=api_client.configuration))"
]
},
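
With the v2 SDK the job runs on a named training runtime instead of a `base_image`, so it can help to check what the cluster offers before calling `train()`. A sketch, assuming the v2 `TrainerClient` API:

```python
# Sketch, assuming the v2 TrainerClient API: list the runtimes installed on
# the cluster, then pick the one referenced by the train() call below.
for runtime in client.list_runtimes():
    print(runtime.name)

runtime = client.get_runtime("training-cuda128-torch28-py312")
```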
{
@@ -333,8 +333,7 @@
"* Check the number of worker nodes\n",
"* Amend the resources per worker according to the job requirements\n",
"* If you use AMD accelerators:\n",
" * Change `nvidia.com/gpu` to `amd.com/gpu` in `resources_per_worker`\n",
" * Change `base_image` to `quay.io/modh/training:py311-rocm62-torch251`\n",
" * Change `runtime` to `training-rocm64-torch28-py312`\n",
"* Update the PVC name to the one you've attached to the workbench if needed"
]
},
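
For the AMD case in the list above, the visible change is the runtime handed to `train()` (plus un-commenting `PYTORCH_HIP_ALLOC_CONF` in `env`). A sketch, using the runtime name from the cell above:

```python
# Sketch: the AMD variant swaps only the runtime; the rest of the call is unchanged.
runtime = client.get_runtime("training-rocm64-torch28-py312")
```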
@@ -345,38 +344,52 @@
"metadata": {},
"outputs": [],
"source": [
"client.create_job(\n",
" job_kind=\"PyTorchJob\",\n",
" name=\"sft\",\n",
" train_func=main,\n",
" num_workers=8,\n",
" num_procs_per_worker=\"1\",\n",
" resources_per_worker={\n",
" \"nvidia.com/gpu\": 1,\n",
" \"memory\": \"64Gi\",\n",
" \"cpu\": 4,\n",
" },\n",
" base_image=\"quay.io/modh/training:py311-cuda124-torch251\",\n",
" env_vars={\n",
" # HuggingFace\n",
" \"HF_HOME\": \"/mnt/shared/.cache\",\n",
" \"HF_TOKEN\": \"\",\n",
" # CUDA / ROCm (HIP)\n",
" \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\",\n",
" \"PYTORCH_HIP_ALLOC_CONF\": \"expandable_segments:True\",\n",
" # NCCL / RCCL\n",
" \"NCCL_DEBUG\": \"INFO\",\n",
" },\n",
" # labels={\"kueue.x-k8s.io/queue-name\": \"<LOCAL_QUEUE_NAME>\"}, # Optional: Add local queue name and uncomment these lines if using Kueue for resource management\n",
" parameters=parameters,\n",
" volumes=[\n",
" V1Volume(name=\"shared\",\n",
" persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name=\"shared\")),\n",
"job_name = client.train(\n",
" trainer=CustomTrainer(\n",
" func=main,\n",
" func_args=parameters,\n",
" num_nodes=4,\n",
" resources_per_node={\n",
" \"cpu\": 4,\n",
" \"memory\": \"128Gi\",\n",
" \"gpu\": 1,\n",
" },\n",
" env={\n",
" # HuggingFace\n",
" \"HF_HOME\": \"/mnt/shared/.cache\",\n",
" \"HF_TOKEN\": \"\",\n",
" # CUDA / ROCm (HIP)\n",
" \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\",\n",
" #\"PYTORCH_HIP_ALLOC_CONF\": \"expandable_segments:True\",\n",
" # NCCL / RCCL\n",
" \"NCCL_DEBUG\": \"INFO\",\n",
" },\n",
" packages_to_install=[\n",
" \"transformers==4.57.1\",\n",
" \"trl==0.24.0\"\n",
" ],\n",
" ),\n",
" options=[\n",
" PodTemplateOverrides(pod_template_overrides=[\n",
" PodTemplateOverride(\n",
" target_jobs=[\"node\"],\n",
" spec=PodTemplateSpecOverride(\n",
" volumes=[\n",
" {\"name\": \"shared\", \"persistentVolumeClaim\": {\"claimName\": \"shared\"}}\n",
" ],\n",
" containers=[\n",
" ContainerOverride(\n",
" name=\"node\",\n",
" volume_mounts=[\n",
" {\"name\": \"shared\", \"mountPath\": \"/mnt/shared\"},\n",
" ]\n",
" )\n",
" ],\n",
" )\n",
" )\n",
" ])\n",
" ],\n",
" volume_mounts=[\n",
" V1VolumeMount(name=\"shared\", mount_path=\"/mnt/shared\"),\n",
" ],\n",
")"
" runtime=client.get_runtime(\"training-cuda128-torch28-py312\"),"
]
},
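
`train()` now returns the generated job name, which the follow-up cells reuse. Before streaming logs, the job can be inspected with something like the sketch below; it assumes the v2 `TrainerClient` API, where `get_job()` returns a `TrainJob` whose `status` field reflects the job phase:

```python
# Sketch, assuming the v2 TrainerClient API and TrainJob fields.
job = client.get_job(job_name)
print(job.name, job.status)  # e.g. status "Running" once all node pods are scheduled
```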
{
@@ -394,11 +407,8 @@
"metadata": {},
"outputs": [],
"source": [
"_ = client.get_job_logs(\n",
" name=\"sft\",\n",
" job_kind=\"PyTorchJob\",\n",
" follow=True,\n",
")"
"for logline in client.get_job_logs(job_name, follow=True):\n",
" print(logline)\""
]
},
{
@@ -650,7 +660,7 @@
"metadata": {},
"outputs": [],
"source": [
"client.delete_job(name=\"sft\")"
"client.delete_job(job_name)"
]
},
{