diff --git a/examples/kft-v2/README.md b/examples/kft-v2/README.md
new file mode 100644
index 000000000..28e842846
--- /dev/null
+++ b/examples/kft-v2/README.md
@@ -0,0 +1,190 @@
+# 🚀 Kubeflow Training V2: Advanced ML Training with Distributed Computing
+
+This directory contains comprehensive examples demonstrating **Kubeflow Training V2** capabilities for distributed training with the Kubeflow Trainer SDK.
+
+## 🎯 **What This Directory Demonstrates**
+
+- **Kubeflow Trainer SDK**: Programmatic TrainJob creation and management (a minimal SDK sketch is shown below)
+- **Checkpointing**: Controller-managed model checkpointing that is compatible with suspending and resuming TrainJobs
+- **Distributed Training**: Multi-node, multi-CPU/GPU coordination with the NCCL and GLOO backends
+
+---
+
+## ✨ **Key Features**
+
+### **TRL (Transformer Reinforcement Learning) Integration**
+- **SFTTrainer**: Supervised fine-tuning for instruction following
+- **PEFT-LoRA**: Parameter-efficient fine-tuning with Low-Rank Adaptation
+- **Model Support**: GPT-2, Llama, and other transformer models
+- **Dataset Integration**: Alpaca dataset for instruction-following tasks
+
+### **Distributed Training Capabilities**
+- **Multi-Node Support**: Scale training across multiple nodes
+- **Multi-GPU Coordination**: NCCL backend for both NVIDIA (CUDA) and AMD (ROCm) GPUs
+- **CPU Training**: GLOO backend for CPU-based training
+- **Resource Flexibility**: Configurable compute resources per node
+
+---
+
+## 📋 **Prerequisites**
+
+### **Cluster Requirements**
+- **OpenShift Cluster**: An OpenShift cluster with OpenShift AI (RHOAI) 2.17+ installed
+- **Required Components**: `dashboard`, `trainingoperator`, and `workbenches` enabled
+- **Storage**: A persistent volume claim named `workspace` of at least 50GB with the RWX (ReadWriteMany) access mode
+
+---
+
+## 🛠️ **Setup Instructions**
+
+### **1. Repository Setup**
+
+Clone the repository and navigate to the kft-v2 directory:
+
+```bash
+git clone https://github.com/opendatahub-io/distributed-workloads.git
+cd distributed-workloads/examples/kft-v2
+```
+
+### **2. Persistent Volume Setup**
+
+Create a shared persistent volume for checkpoint storage:
+
+```bash
+oc apply -f manifests/shared_pvc.yaml
+```
+
+### **3. Cluster Training Runtime Setup**
+
+Apply the cluster training runtime configuration:
+
+```bash
+oc apply -f manifests/cluster_training_runtime.yaml
+```
+
+This creates the `torch-cuda-custom` ClusterTrainingRuntime resource used by the PyTorch training examples.
+
+### **4. Workbench Setup**
+
+* Access the OpenShift AI dashboard, for example from the top navigation bar menu:
+
+![](./docs/01.png)
+
+* Log in, then go to _Data Science Projects_ and create a project:
+
+![](./docs/02.png)
+
+* Once the project is created, click on _Create a workbench_:
+
+![](./docs/03.png)
+
+* Then create a workbench with the following settings:
+
+  * Select the `PyTorch` (or the `ROCm-PyTorch`) notebook image:
+
+  * Select the _Medium_ container size and a sufficient persistent storage volume.
+
+  ![](./docs/04.png)
+
+  ![](./docs/05.png)
+
+  > [!NOTE]
+  >
+  > * Adding an accelerator is only needed to test the fine-tuned model from within the workbench, so you can skip it if accelerators are scarce.
+  > * The default 20GB of workbench storage is enough to run inference from within the workbench.
+
+  * Review the configuration and click _Create workbench_
+
+* From the _Workbenches_ page, click on _Open_ when the workbench you've just created becomes ready:
+
+![](./docs/06.png)
+
+---
+
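+## 🧩 **Creating a TrainJob with the Kubeflow Trainer SDK**
+
+The notebook in this directory drives training programmatically through the Kubeflow Trainer SDK rather than by applying TrainJob manifests by hand. The snippet below is a minimal sketch of that flow, assuming the `torch-cuda-custom` runtime from `manifests/` has already been applied and the SDK (`kubeflow` package) is installed in the workbench. Exact class and method names (`TrainerClient`, `CustomTrainer`, `get_runtime`, `wait_for_job_status`, `get_job_logs`) may vary between SDK releases, so treat this as a starting point and use the notebook as the reference.
+
+```python
+from kubeflow.trainer import TrainerClient, CustomTrainer
+
+from scripts.mnist import train_fashion_mnist
+
+client = TrainerClient()
+
+# Submit the self-contained training function as a TrainJob backed by the
+# ClusterTrainingRuntime created from manifests/cluster_training_runtime.yaml.
+job_name = client.train(
+    runtime=client.get_runtime("torch-cuda-custom"),
+    trainer=CustomTrainer(
+        func=train_fashion_mnist,
+        num_nodes=2,
+        resources_per_node={"cpu": 2, "memory": "4Gi"},
+    ),
+)
+
+# Wait for the TrainJob to complete, then dump its logs
+# (these helper method names may differ between SDK releases).
+client.wait_for_job_status(name=job_name)
+print(client.get_job_logs(name=job_name))
+```
+
+While the job runs, you can also watch the underlying resources directly with `oc get trainjobs,jobsets,pods`.
+
+---
+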
+## 🚀 **Quick Start Examples**
+
+Both example scripts are self-contained functions that read their configuration from environment variables and take no arguments, so they can be exercised directly in the workbench or submitted as TrainJobs through the SDK.
+
+### **Example 1: Fashion-MNIST Training**
+
+Run the Fashion-MNIST training example:
+
+```python
+import os
+
+from scripts.mnist import train_fashion_mnist
+
+# Configure training parameters via environment variables
+os.environ["NUM_EPOCHS"] = "10"
+os.environ["BATCH_SIZE"] = "64"
+os.environ["LEARNING_RATE"] = "0.001"
+os.environ["CHECKPOINT_DIR"] = "/workspace/checkpoints"
+
+# Start training
+train_fashion_mnist()
+```
+
+### **Example 2: TRL GPT-2 Fine-tuning**
+
+Run the TRL training example:
+
+```python
+import os
+
+from scripts.trl_training import trl_train
+
+# Configure TRL parameters via environment variables
+os.environ["MODEL_NAME"] = "gpt2"
+os.environ["DATASET_NAME"] = "tatsu-lab/alpaca"
+os.environ["LORA_R"] = "16"
+os.environ["LORA_ALPHA"] = "32"
+os.environ["MAX_EPOCHS"] = "3"
+
+# Start TRL training
+trl_train()
+```
+
+![](./docs/07.png)
+
+---
+
+## 📊 **Training Examples**
+
+### **Fashion-MNIST Classification**
+
+The `mnist.py` script demonstrates:
+
+- **Distributed Training**: Multi-GPU Fashion-MNIST classification
+- **Checkpointing**: Automatic checkpoint creation and resumption
+- **Progress Tracking**: Real-time training progress monitoring
+- **Error Handling**: Robust error handling and recovery
+
+**Key Features:**
+- CNN architecture for image classification
+- Distributed data loading with DistributedSampler
+- Automatic NCCL/GLOO backend selection based on GPU availability
+- Comprehensive logging and metrics
+
+### **TRL GPT-2 Fine-tuning**
+
+The `trl_training.py` script demonstrates:
+
+- **Instruction Following**: Fine-tuning GPT-2 on the Alpaca dataset
+- **PEFT-LoRA**: Parameter-efficient fine-tuning
+- **Checkpoint Management**: TRL-compatible checkpointing
+- **Distributed Coordination**: Multi-node training coordination
+
+**Key Features:**
+- SFTTrainer for supervised fine-tuning
+- LoRA adapters for efficient parameter updates
+- Instruction-following dataset processing
+- Hugging Face model integration
+
+---
+
+## 📚 **References and Documentation**
+
+- **[Kubeflow Trainer SDK](https://github.com/kubeflow/sdk)**: Official SDK documentation
+- **[TRL Documentation](https://huggingface.co/docs/trl/)**: Transformer Reinforcement Learning
+- **[PEFT Documentation](https://huggingface.co/docs/peft/)**: Parameter-Efficient Fine-Tuning
+- **[PyTorch Distributed Training](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)**: Distributed training guide
+- **[OpenShift AI Documentation](https://access.redhat.com/documentation/en-us/red_hat_openshift_ai)**: RHOAI documentation
+
+---
diff --git a/examples/kft-v2/docs/01.png b/examples/kft-v2/docs/01.png
new file mode 100644
index 000000000..ede6f88ca
Binary files /dev/null and b/examples/kft-v2/docs/01.png differ
diff --git a/examples/kft-v2/docs/02.png b/examples/kft-v2/docs/02.png
new file mode 100644
index 000000000..b747d14c1
Binary files /dev/null and b/examples/kft-v2/docs/02.png differ
diff --git a/examples/kft-v2/docs/03.png b/examples/kft-v2/docs/03.png
new file mode 100644
index 000000000..5a08ee1ba
Binary files /dev/null and b/examples/kft-v2/docs/03.png differ
diff --git a/examples/kft-v2/docs/04.png b/examples/kft-v2/docs/04.png
new file mode 100644
index 000000000..09e9ddd0d
Binary files /dev/null and b/examples/kft-v2/docs/04.png differ
diff --git a/examples/kft-v2/docs/05.png b/examples/kft-v2/docs/05.png
new file mode 100644
index 000000000..cfed07ebc
Binary files /dev/null and b/examples/kft-v2/docs/05.png differ
diff --git 
a/examples/kft-v2/docs/06.png b/examples/kft-v2/docs/06.png new file mode 100644 index 000000000..a3fd39b46 Binary files /dev/null and b/examples/kft-v2/docs/06.png differ diff --git a/examples/kft-v2/docs/07.png b/examples/kft-v2/docs/07.png new file mode 100644 index 000000000..5a7dc2ac9 Binary files /dev/null and b/examples/kft-v2/docs/07.png differ diff --git a/examples/kft-v2/docs/jobs.png b/examples/kft-v2/docs/jobs.png new file mode 100644 index 000000000..355d534d0 Binary files /dev/null and b/examples/kft-v2/docs/jobs.png differ diff --git a/examples/kft-v2/docs/trainjob_pods.png b/examples/kft-v2/docs/trainjob_pods.png new file mode 100644 index 000000000..9a825ed5b Binary files /dev/null and b/examples/kft-v2/docs/trainjob_pods.png differ diff --git a/examples/kft-v2/docs/trainjobs_jobsets.png b/examples/kft-v2/docs/trainjobs_jobsets.png new file mode 100644 index 000000000..240e26091 Binary files /dev/null and b/examples/kft-v2/docs/trainjobs_jobsets.png differ diff --git a/examples/kft-v2/manifests/cluster_training_runtime.yaml b/examples/kft-v2/manifests/cluster_training_runtime.yaml new file mode 100644 index 000000000..d2b1863ad --- /dev/null +++ b/examples/kft-v2/manifests/cluster_training_runtime.yaml @@ -0,0 +1,141 @@ +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: torch-cuda-custom + labels: + trainer.kubeflow.org/framework: torch +spec: + mlPolicy: + numNodes: 2 + torch: + numProcPerNode: 1 + template: + metadata: {} + spec: + replicatedJobs: + - name: dataset-initializer + replicas: 1 + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer + spec: + template: + spec: + containers: + - env: + - name: HF_HOME + value: /workspace/cache + - name: DATASET_NAME + value: tatsu-lab/alpaca + - name: DATASET_CONFIG + value: main + - name: DATASET_SPLIT + value: 'train[:500]' + - name: DATASET_FORMAT + value: json + - name: WORKSPACE_PATH + value: /workspace + image: 'ghcr.io/kubeflow/trainer/dataset-initializer:v2.0.0' + name: dataset-initializer + resources: + limits: + cpu: '2' + memory: 4Gi + requests: + cpu: '1' + memory: 2Gi + volumeMounts: + - mountPath: /workspace + name: workspace + restartPolicy: Never + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace + - dependsOn: + - name: dataset-initializer + status: Complete + name: model-initializer + replicas: 1 + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: model-initializer + spec: + template: + spec: + containers: + - env: + - name: HF_HOME + value: /workspace/cache + - name: MODEL_NAME + value: gpt2 + - name: MODEL_REVISION + value: main + - name: DOWNLOAD_MODE + value: force_redownload + - name: WORKSPACE_PATH + value: /workspace + image: 'ghcr.io/kubeflow/trainer/model-initializer:v2.0.0' + name: model-initializer + resources: + limits: + cpu: '2' + memory: 4Gi + requests: + cpu: '1' + memory: 2Gi + volumeMounts: + - mountPath: /workspace + name: workspace + restartPolicy: Never + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace + - dependsOn: + - name: model-initializer + status: Complete + name: node + replicas: 1 + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: trainer + spec: + template: + metadata: {} + spec: + containers: + - env: + - name: PYTHONUNBUFFERED + value: '1' + - name: NCCL_DEBUG + value: INFO + - name: NCCL_SOCKET_IFNAME + value: eth0 + - name: NCCL_IB_DISABLE + value: '1' + - name: 
NCCL_P2P_DISABLE + value: '1' + - name: TRAINJOB_PROGRESSION_FILE_PATH + value: /tmp/training_progression.json + - name: CHECKPOINT_DIR + value: /workspace/checkpoints + image: 'quay.io/modh/training:py311-cuda124-torch251' + name: node + resources: + limits: + cpu: '2' + memory: 4Gi + requests: + cpu: '1' + memory: 2Gi + volumeMounts: + - mountPath: /workspace + name: workspace + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace \ No newline at end of file diff --git a/examples/kft-v2/manifests/shared_pvc.yaml b/examples/kft-v2/manifests/shared_pvc.yaml new file mode 100644 index 000000000..f5c1fb3c5 --- /dev/null +++ b/examples/kft-v2/manifests/shared_pvc.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: workspace +spec: + accessModes: + - ReadWriteMany + resources: + requests: + storage: 50Gi + storageClassName: nfs-csi + volumeMode: Filesystem \ No newline at end of file diff --git a/examples/kft-v2/scripts/mnist.py b/examples/kft-v2/scripts/mnist.py new file mode 100644 index 000000000..05aef792b --- /dev/null +++ b/examples/kft-v2/scripts/mnist.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +def train_fashion_mnist(): + """ + Fashion-MNIST PyTorch training script with progression tracking and checkpointing. + """ + + import json + import os + import time + from typing import Optional + from pathlib import Path + import glob + + import torch + import torch.distributed as dist + import torch.nn.functional as F + from torch import nn + from torch.utils.data import DataLoader, DistributedSampler + from torchvision import datasets, transforms + + class ProgressionTracker: + """Helper class to track and write training progression.""" + + def __init__( + self, + total_epochs: int, + steps_per_epoch: int, + status_file_path: Optional[str] = None, + update_interval: int = 30, + ): + """ + Initialize progression tracker. + + Args: + total_epochs: Total number of training epochs + steps_per_epoch: Number of steps per epoch + status_file_path: Path where progression status will be written. + If None, uses TRAINJOB_PROGRESSION_FILE_PATH env var or default. 
+ update_interval: Minimum seconds between status updates + """ + self.total_epochs = total_epochs + self.steps_per_epoch = steps_per_epoch + self.total_steps = total_epochs * steps_per_epoch + self.status_file_path = status_file_path or os.getenv( + "TRAINJOB_PROGRESSION_FILE_PATH", "/tmp/training_progression.json" + ) + self.update_interval = update_interval + self.start_time = time.time() + self.last_update_time = 0 + self.current_epoch = 0 + self.current_step = 0 + self.metrics = {} + + def update_step( + self, + epoch: int, + step: int, + loss: float = None, + learning_rate: float = None, + checkpoint_dir: str = None, + **kwargs, + ): + """Update current step and optionally write status.""" + # Track cumulative progress across entire training lifecycle + # epoch is the absolute epoch number (1-based) + # step is the current batch within the epoch (0-based) + self.current_epoch = epoch + self.current_step = (epoch - 1) * self.steps_per_epoch + step + 1 + + # Separate optional structured training metrics and custom generic metrics + training_metrics = {} + generic_metrics = {} + + # Core training metrics + if loss is not None: + training_metrics["loss"] = str(loss) + if learning_rate is not None: + training_metrics["learning_rate"] = str(learning_rate) + + # Add checkpoint information if available + if checkpoint_dir and os.path.exists(checkpoint_dir): + try: + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint-') or f.startswith('epoch-')] + if checkpoints: + training_metrics["checkpoints_stored"] = len(checkpoints) + # Find latest checkpoint by highest number + def get_checkpoint_number(checkpoint_name): + try: + # Handle both checkpoint-N and epoch-N formats + if 'checkpoint-' in checkpoint_name: + return int(checkpoint_name.split('-')[1].split('.')[0]) + elif 'epoch-' in checkpoint_name: + return int(checkpoint_name.split('-')[1].split('.')[0]) + else: + return -1 + except (IndexError, ValueError): + return -1 + + latest_checkpoint_name = max(checkpoints, key=get_checkpoint_number) + latest_checkpoint = os.path.join(checkpoint_dir, latest_checkpoint_name) + training_metrics["latest_checkpoint_path"] = latest_checkpoint + except (OSError, ValueError): + pass + + # Process additional metrics + for key, value in kwargs.items(): + str_value = str(value) + + # Map to structured TrainingMetrics fields + if key in ['accuracy', 'train_accuracy']: + training_metrics["accuracy"] = str_value + else: + # Everything else goes to generic metrics + generic_metrics[key] = str_value + + # Store metrics for status writing + self.training_metrics = training_metrics + self.generic_metrics = generic_metrics + + # Write status + current_time = time.time() + if current_time - self.last_update_time >= self.update_interval: + message = f"Training step {self.current_step}/{self.total_steps}" + self.write_status(message) + self.last_update_time = current_time + + def update_epoch(self, epoch: int, checkpoint_dir: str = None, **metrics): + """Update current epoch and write status.""" + self.current_epoch = epoch + + # Separate structured training metrics and generic metrics + training_metrics = {} + generic_metrics = {} + + # Process epoch metrics + for key, value in metrics.items(): + str_value = str(value) + + # Map to structured TrainingMetrics fields + if key in ['loss', 'avg_loss', 'train_loss']: + training_metrics["loss"] = str_value + elif key in ['accuracy', 'train_accuracy']: + training_metrics["accuracy"] = str_value + else: + # Everything else goes to generic metrics + 
generic_metrics[key] = str_value + + # Add checkpoint information if available + if checkpoint_dir and os.path.exists(checkpoint_dir): + try: + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint-') or f.startswith('epoch-')] + if checkpoints: + training_metrics["checkpoints_stored"] = len(checkpoints) + # Find latest checkpoint by highest number + def get_checkpoint_number(checkpoint_name): + try: + if 'checkpoint-' in checkpoint_name: + return int(checkpoint_name.split('-')[1].split('.')[0]) + elif 'epoch-' in checkpoint_name: + return int(checkpoint_name.split('-')[1].split('.')[0]) + else: + return -1 + except (IndexError, ValueError): + return -1 + + latest_checkpoint_name = max(checkpoints, key=get_checkpoint_number) + latest_checkpoint = os.path.join(checkpoint_dir, latest_checkpoint_name) + training_metrics["latest_checkpoint_path"] = latest_checkpoint + except (OSError, ValueError): + pass + + # Store metrics for status writing + self.training_metrics = training_metrics + self.generic_metrics = generic_metrics + + epoch_num = epoch + 1 + total_epochs = self.total_epochs + message = f"Completed epoch {epoch_num}/{total_epochs}" + self.write_status(message) + + def write_status(self, message: str = "Training in progress"): + """Write current training status to file.""" + try: + current_time = time.time() + + # Basic status data + status_data = { + "message": message, + "timestamp": int(current_time), + "start_time": int(self.start_time), + "current_step": self.current_step, + "total_steps": self.total_steps, + "current_epoch": self.current_epoch, + "total_epochs": self.total_epochs, + } + + # Calculate percentage if we have step info + if self.total_steps > 0: + percentage = (self.current_step / self.total_steps) * 100 + status_data["percentage_complete"] = f"{percentage:.2f}" + + # Calculate ETA if we have progress + if self.current_step > 0: + elapsed_time = current_time - self.start_time + time_per_step = elapsed_time / self.current_step + remaining_steps = self.total_steps - self.current_step + eta_seconds = int(remaining_steps * time_per_step) + status_data["estimated_time_remaining"] = eta_seconds + + # Add structured training metrics if any + if hasattr(self, 'training_metrics') and self.training_metrics: + status_data["training_metrics"] = self.training_metrics + + # Add generic metrics if any + if hasattr(self, 'generic_metrics') and self.generic_metrics: + status_data["metrics"] = self.generic_metrics + + # Write to file atomically + temp_file = f"{self.status_file_path}.tmp" + with open(temp_file, "w") as f: + json.dump(status_data, f, indent=2) + os.rename(temp_file, self.status_file_path) + + except Exception as e: + print(f"Failed to write progression status: {e}") + + # Define the PyTorch CNN model to be trained + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 20, 5, 1) + self.conv2 = nn.Conv2d(20, 50, 5, 1) + self.fc1 = nn.Linear(4 * 4 * 50, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + def setup_distributed(): + """Initialize distributed training using operator-injected PET environment variables""" + # Use PET_* environment variables injected by the training operator + node_rank = int(os.getenv('PET_NODE_RANK', '0')) + num_nodes = 
int(os.getenv('PET_NNODES', '1')) + nproc_per_node = int(os.getenv('PET_NPROC_PER_NODE', '1')) + master_addr = os.getenv('PET_MASTER_ADDR', 'localhost') + master_port = os.getenv('PET_MASTER_PORT', '29500') + + # Calculate standard PyTorch distributed variables + local_rank = int(os.getenv('LOCAL_RANK', '0')) + world_size = num_nodes * nproc_per_node + global_rank = node_rank * nproc_per_node + local_rank + + # Set standard PyTorch environment variables for compatibility + os.environ['RANK'] = str(global_rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + + # Use NCCL if a GPU is available, otherwise use Gloo as communication backend. + device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo") + print(f"Using Device: {device}, Backend: {backend}") + + # Initialize distributed training if world_size > 1 + if world_size > 1: + try: + torch.distributed.init_process_group( + backend=backend, + rank=global_rank, + world_size=world_size + ) + torch.distributed.barrier() + print( + "Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}".format( + world_size, global_rank, local_rank + ) + ) + except Exception as e: + print(f"Warning: Failed to initialize distributed training: {e}") + else: + print("Single node training - distributed not initialized") + + return local_rank, global_rank, world_size, device + + def train_fashion_mnist(): + # Setup distributed training + local_rank, global_rank, world_size, device_type = setup_distributed() + + # Create the model and load it into the device. + if device_type == "cuda" and torch.cuda.is_available(): + device = torch.device(f"{device_type}:{local_rank}") + else: + device = torch.device("cpu") + + # Create model and wrap with DDP only if distributed + net = Net().to(device) + if world_size > 1: + model = nn.parallel.DistributedDataParallel(net) + else: + model = net + optimizer = torch.optim.SGD(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '0.1')), momentum=0.9) + + # Setup checkpointing + checkpoint_dir = os.getenv('CHECKPOINT_DIR', '/workspace/checkpoints') + os.makedirs(checkpoint_dir, exist_ok=True) + + # Resume from checkpoint if available + start_epoch = 1 + if world_size == 1 or dist.get_rank() == 0: + # Find latest checkpoint + checkpoints = glob.glob(os.path.join(checkpoint_dir, 'epoch-*.pth')) + if checkpoints: + def get_epoch_number(checkpoint_path): + try: + filename = os.path.basename(checkpoint_path) + return int(filename.split('-')[1].split('.')[0]) + except (IndexError, ValueError): + return -1 + + latest_checkpoint = max(checkpoints, key=get_epoch_number) + print(f"Resuming from checkpoint: {latest_checkpoint}") + + try: + # Load checkpoint with device mapping for CPU/GPU compatibility + checkpoint = torch.load(latest_checkpoint, map_location=device) + if world_size > 1: + model.module.load_state_dict(checkpoint['model_state_dict']) + else: + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + start_epoch = checkpoint['epoch'] + 1 + print(f"Resumed from epoch {checkpoint['epoch']}, starting epoch {start_epoch}") + except Exception as e: + print(f"Failed to load checkpoint: {e}") + print("Starting training from scratch...") + start_epoch = 1 + + # Broadcast start_epoch to all ranks (only if distributed) + if world_size > 1: + if torch.cuda.is_available(): + start_epoch_tensor = 
torch.tensor(start_epoch, device=device) + dist.broadcast(start_epoch_tensor, src=0) + start_epoch = start_epoch_tensor.item() + else: + # For CPU training, use a different approach + start_epoch_list = [start_epoch] if dist.get_rank() == 0 else [None] + dist.broadcast_object_list(start_epoch_list, src=0) + start_epoch = start_epoch_list[0] + + # Download FashionMNIST dataset only on local_rank=0 process. + if local_rank == 0: + dataset = datasets.FashionMNIST( + "./data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor()]), + ) + if world_size > 1: + dist.barrier() + dataset = datasets.FashionMNIST( + "./data", + train=True, + download=False, + transform=transforms.Compose([transforms.ToTensor()]), + ) + + # Shard the dataset across workers (only if distributed). + if world_size > 1: + train_loader = DataLoader( + dataset, + batch_size=int(os.getenv('BATCH_SIZE', '100')), + sampler=DistributedSampler(dataset) + ) + else: + train_loader = DataLoader( + dataset, + batch_size=int(os.getenv('BATCH_SIZE', '100')), + shuffle=True + ) + + # Initialize progression tracker (only on rank 0) + tracker = None + if world_size == 1 or dist.get_rank() == 0: + num_epochs = int(os.getenv('NUM_EPOCHS', '5')) + steps_per_epoch = len(train_loader) + + # Calculate total epochs for entire training plan + total_epochs_planned = start_epoch + num_epochs - 1 + + tracker = ProgressionTracker( + total_epochs=total_epochs_planned, + steps_per_epoch=steps_per_epoch, + update_interval=int(os.getenv('PROGRESSION_UPDATE_INTERVAL', '10')) + ) + + # Initialize tracker with cumulative progress across entire training lifecycle + if start_epoch > 1: + # Set progress based on completed epochs (cumulative) + completed_epochs = start_epoch - 1 + tracker.current_epoch = completed_epochs + tracker.current_step = completed_epochs * steps_per_epoch + tracker.write_status(f"Training resumed from epoch {start_epoch}") + else: + tracker.current_epoch = 0 + tracker.current_step = 0 + tracker.write_status("Training started") + + if world_size > 1: + dist.barrier() + + # Training loop with progression tracking and checkpointing + num_epochs = int(os.getenv('NUM_EPOCHS', '5')) + for epoch in range(start_epoch, start_epoch + num_epochs): + model.train() + epoch_loss = 0.0 + num_batches = 0 + + # Iterate over mini-batches from the training set + for batch_idx, (inputs, labels) in enumerate(train_loader): + # Copy the data to the GPU device if available + inputs, labels = inputs.to(device), labels.to(device) + + # Forward pass + outputs = model(inputs) + loss = F.nll_loss(outputs, labels) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Track metrics for epoch average + epoch_loss += loss.item() + num_batches += 1 + + # Update progression (only on rank 0) + if tracker and batch_idx % 10 == 0: + current_lr = optimizer.param_groups[0]["lr"] + + # Calculate samples per second + current_time = time.time() + elapsed_time = current_time - tracker.start_time + total_samples_processed = (epoch - start_epoch) * len(train_loader) * int(os.getenv('BATCH_SIZE', '100')) + batch_idx * int(os.getenv('BATCH_SIZE', '100')) + samples_per_second = total_samples_processed / elapsed_time if elapsed_time > 0 else 0 + + # Calculate accuracy (simple approximation) + with torch.no_grad(): + _, predicted = torch.max(outputs.data, 1) + correct = (predicted == labels).sum().item() + accuracy = correct / labels.size(0) + + # Use absolute epoch for cumulative progress + tracker.update_step( + 
epoch=epoch, + step=batch_idx, + loss=loss.item(), + learning_rate=current_lr, + checkpoint_dir=checkpoint_dir, + accuracy=accuracy, + world_size=dist.get_world_size(), + local_rank=local_rank, + train_samples_per_second=f"{samples_per_second:.2f}", + train_runtime=f"{elapsed_time:.1f}", + grad_norm=f"{torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0):.4f}" + ) + + if batch_idx % 10 == 0 and (world_size == 1 or dist.get_rank() == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(inputs), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + # Save checkpoint at the end of each epoch (only on rank 0) + if world_size == 1 or dist.get_rank() == 0: + checkpoint_path = os.path.join(checkpoint_dir, f'epoch-{epoch}.pth') + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.module.state_dict() if world_size > 1 else model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': epoch_loss / num_batches if num_batches > 0 else 0.0, + }, checkpoint_path) + print(f"Checkpoint saved: {checkpoint_path}") + + # Update epoch progression (only on rank 0) + if tracker: + avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0 + + # Calculate epoch accuracy (simple approximation) + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for data, target in train_loader: + data, target = data.to(device), target.to(device) + outputs = model(data) + _, predicted = torch.max(outputs.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + + epoch_accuracy = correct / total if total > 0 else 0.0 + + # Use absolute epoch for cumulative progress + tracker.update_epoch( + epoch=epoch, + checkpoint_dir=checkpoint_dir, + avg_loss=avg_loss, + accuracy=epoch_accuracy, + total_batches=num_batches, + total_samples=total + ) + + if world_size > 1: + dist.barrier() + + # Wait for the distributed training to complete + if world_size > 1: + dist.barrier() + + if world_size == 1 or dist.get_rank() == 0: + print("Training is finished") + if tracker: + # Write final completion status + tracker.current_step = tracker.total_steps # Ensure 100% completion + tracker.write_status("Training completed") + + # Buffer time to ensure controller captures 100% completion + print("Waiting for progression status to be captured...") + time.sleep(30) # Buffer time to update the progression status to 100% + + # Finally clean up PyTorch distributed + if world_size > 1: + dist.destroy_process_group() + +if __name__ == "__main__": + train_fashion_mnist() \ No newline at end of file diff --git a/examples/kft-v2/scripts/trl_training.py b/examples/kft-v2/scripts/trl_training.py new file mode 100644 index 000000000..062116f4c --- /dev/null +++ b/examples/kft-v2/scripts/trl_training.py @@ -0,0 +1,790 @@ +def trl_train(): + """TRL training script with distributed coordination and checkpointing.""" + + import os + import json + import time + import signal + import torch + import numpy + from numpy.core.multiarray import _reconstruct + import torch.serialization + import torch.distributed as dist + from datetime import datetime + from pathlib import Path + from typing import Optional + from datasets import load_dataset, load_from_disk + from transformers import ( + AutoTokenizer, + TrainingArguments, + TrainerState, + TrainerControl, + TrainerCallback, + set_seed, + ) + from transformers.trainer_utils import get_last_checkpoint + from trl import ( + ModelConfig, + ScriptArguments, 
+ SFTConfig, + SFTTrainer, + TrlParser, + get_peft_config, + ) + + torch.serialization.add_safe_globals([_reconstruct, numpy.ndarray, numpy.dtype, numpy.dtypes.UInt32DType]) + + class ProgressionTracker: + """Tracks and writes training progression.""" + + def __init__( + self, + total_epochs: int, + steps_per_epoch: int, + status_file_path: Optional[str] = None, + update_interval: int = 30, + ): + self.total_epochs = total_epochs + self.steps_per_epoch = steps_per_epoch + self.total_steps = total_epochs * steps_per_epoch + self.status_file_path = status_file_path or os.getenv( + "TRAINJOB_PROGRESSION_FILE_PATH", "/tmp/training_progression.json" + ) + self.update_interval = update_interval + self.start_time = time.time() + self.last_update_time = 0 + self.current_epoch = 0 + self.current_step = 0 + self.metrics = {} + + def update_step(self, epoch: int, step: int, loss: float = None, learning_rate: float = None, checkpoint_dir: str = None, **kwargs): + self.current_epoch = epoch + if 'global_step' in kwargs: + self.current_step = int(kwargs['global_step']) + else: + self.current_step = (epoch - 1) * self.steps_per_epoch + step + + training_metrics = {} + generic_metrics = {} + + if loss is not None: + training_metrics["loss"] = str(loss) + if learning_rate is not None: + training_metrics["learning_rate"] = str(learning_rate) + + if checkpoint_dir and os.path.exists(checkpoint_dir): + try: + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint-') or f.startswith('epoch-')] + if checkpoints: + training_metrics["checkpoints_stored"] = len(checkpoints) + def get_checkpoint_number(name): + try: + return int(name.split('-')[1].split('.')[0]) + except (IndexError, ValueError): + return -1 + latest_checkpoint = os.path.join(checkpoint_dir, max(checkpoints, key=get_checkpoint_number)) + training_metrics["latest_checkpoint_path"] = latest_checkpoint + except (OSError, ValueError): + pass + + for key, value in kwargs.items(): + str_value = str(value) + + if key in ['accuracy', 'train_accuracy']: + training_metrics["accuracy"] = str_value + else: + generic_metrics[key] = str_value + + self.training_metrics = training_metrics + self.generic_metrics = generic_metrics + + current_time = time.time() + if current_time - self.last_update_time >= self.update_interval: + message = f"Training step {self.current_step}/{self.total_steps}" + self.write_status(message) + self.last_update_time = current_time + + def update_epoch(self, epoch: int, checkpoint_dir: str = None, **metrics): + self.current_epoch = epoch + + training_metrics = {} + generic_metrics = {} + + for key, value in metrics.items(): + str_value = str(value) + + if key in ['loss', 'avg_loss', 'train_loss']: + training_metrics["loss"] = str_value + elif key in ['accuracy', 'train_accuracy']: + training_metrics["accuracy"] = str_value + else: + generic_metrics[key] = str_value + + if checkpoint_dir and os.path.exists(checkpoint_dir): + try: + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint-') or f.startswith('epoch-')] + if checkpoints: + training_metrics["checkpoints_stored"] = len(checkpoints) + def get_checkpoint_number(name): + try: + return int(name.split('-')[1].split('.')[0]) + except (IndexError, ValueError): + return -1 + latest_checkpoint = os.path.join(checkpoint_dir, max(checkpoints, key=get_checkpoint_number)) + training_metrics["latest_checkpoint_path"] = latest_checkpoint + except (OSError, ValueError): + pass + + self.training_metrics = training_metrics + self.generic_metrics 
= generic_metrics + message = f"Completed epoch {epoch}/{self.total_epochs}" + self.write_status(message) + + def write_status(self, message: str = "Training in progress"): + """Write training status to file.""" + try: + current_time = time.time() + + status_data = { + "message": message, + "timestamp": int(current_time), + "start_time": int(self.start_time), + "current_step": self.current_step, + "total_steps": self.total_steps, + "current_epoch": self.current_epoch, + "total_epochs": self.total_epochs, + } + + if self.total_steps > 0: + percentage = (self.current_step / self.total_steps) * 100 + status_data["percentage_complete"] = f"{percentage:.2f}" + + if self.current_step > 0: + elapsed_time = current_time - self.start_time + time_per_step = elapsed_time / self.current_step + remaining_steps = self.total_steps - self.current_step + eta_seconds = int(remaining_steps * time_per_step) + days, hours, minutes, seconds = eta_seconds // 86400, (eta_seconds % 86400) // 3600, (eta_seconds % 3600) // 60, eta_seconds % 60 + eta_formatted = "" + if days > 0: eta_formatted += f"{days}d" + if hours > 0: eta_formatted += f"{hours}h" + if minutes > 0: eta_formatted += f"{minutes}m" + if seconds > 0 or eta_formatted == "": eta_formatted += f"{seconds}s" + status_data["estimated_time_remaining"] = eta_formatted + + if hasattr(self, 'training_metrics') and self.training_metrics: + status_data["training_metrics"] = self.training_metrics + + if hasattr(self, 'generic_metrics') and self.generic_metrics: + status_data["metrics"] = self.generic_metrics + + temp_file = f"{self.status_file_path}.tmp" + with open(temp_file, "w") as f: + json.dump(status_data, f, indent=2) + os.rename(temp_file, self.status_file_path) + + except Exception as e: + print(f"Failed to write progression status: {e}") + + original_torch_load = torch.load + def patched_torch_load(*args, **kwargs): + if 'weights_only' not in kwargs: kwargs['weights_only'] = False + if 'map_location' not in kwargs: kwargs['map_location'] = 'cuda' if torch.cuda.is_available() else 'cpu' + return original_torch_load(*args, **kwargs) + torch.load = patched_torch_load + + class DistributedCheckpointCallback(TrainerCallback): + def __init__(self, output_dir: str, progression_tracker: Optional[ProgressionTracker] = None): + self.output_dir = output_dir + self.checkpoint_requested = False + self.save_triggered = False + self.checkpoint_stream = None + self.sigterm_tensor = None + self.progression_tracker = progression_tracker + + self.checkpoint_enabled = os.environ.get('CHECKPOINT_ENABLED', 'false').lower() == 'true' + self.checkpoint_uri = os.environ.get('CHECKPOINT_URI', '/workspace/checkpoints') + + self.progress_file = os.environ.get('TRAINING_PROGRESS_FILE', '/workspace/training_progress.json') + + + def _log_message(self, message: str): + print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {message}") + + def _write_progress(self, state: TrainerState): + rank = int(os.environ.get('RANK', '0')) + if rank != 0: + return + + if self.progression_tracker: + try: + latest_loss = 0.0 + latest_lr = 0.0 + if state.log_history: + latest_log = state.log_history[-1] + latest_loss = latest_log.get('loss', latest_log.get('train_loss', latest_log.get('training_loss', 0.0))) + latest_lr = latest_log.get('learning_rate', latest_log.get('lr', latest_log.get('train_lr', 0.0))) + + epoch = max(1, int(state.epoch)) if state.epoch is not None else 1 + step_in_epoch = (state.global_step - 1) % self.progression_tracker.steps_per_epoch if 
self.progression_tracker.steps_per_epoch > 0 else 0 + + self.progression_tracker.update_step( + epoch=epoch, + step=step_in_epoch, + loss=latest_loss, + learning_rate=latest_lr, + checkpoint_dir=self.output_dir, + global_step=state.global_step, + max_steps=state.max_steps, + num_train_epochs=state.num_train_epochs + ) + except Exception as e: + print(f"ProgressionTracker update failed: {e}") + self._write_simple_progress(state) + else: + self._write_simple_progress(state) + + def _write_simple_progress(self, state: TrainerState): + try: + latest_loss = latest_lr = 0.0 + if state.log_history: + latest_log = state.log_history[-1] + latest_loss = latest_log.get('loss', latest_log.get('train_loss', latest_log.get('training_loss', 0.0))) + latest_lr = latest_log.get('learning_rate', latest_log.get('lr', latest_log.get('train_lr', 0.0))) + + progress_data = { + "epoch": int(state.epoch) if state.epoch else 1, + "totalEpochs": int(state.num_train_epochs) if state.num_train_epochs else 1, + "step": state.global_step, + "totalSteps": state.max_steps, + "loss": f"{latest_loss:.4f}", + "learningRate": f"{latest_lr:.6f}", + "percentComplete": f"{(state.global_step / state.max_steps * 100):.1f}" if state.max_steps > 0 else "0.0", + "lastUpdateTime": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + } + + temp_file = self.progress_file + '.tmp' + with open(temp_file, 'w') as f: + json.dump(progress_data, f, indent=2) + os.rename(temp_file, self.progress_file) + os.chmod(self.progress_file, 0o644) + + except Exception as e: + pass + + def _init_distributed_signal_tensor(self): + try: + if dist.is_initialized(): + device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu') + self.sigterm_tensor = torch.zeros(1, dtype=torch.float32, device=device) + self._log_message(f"Initialized distributed SIGTERM tensor on device: {device}") + else: + self._log_message("Distributed training not initialized - using local SIGTERM handling only") + except Exception as e: + self._log_message(f"Failed to initialize distributed SIGTERM tensor: {e}. Using local handling only.") + + def _check_distributed_sigterm(self): + try: + if dist.is_initialized() and self.sigterm_tensor is not None: + dist.all_reduce(self.sigterm_tensor, op=dist.ReduceOp.MAX) + return self.sigterm_tensor.item() > 0.5 + except Exception as e: + self._log_message(f"Distributed SIGTERM check failed: {e}. 
Using local signal only.") + return self.checkpoint_requested + + def _sigterm_handler(self, signum, frame): + rank = os.environ.get("RANK", "-1") + self._log_message(f"Rank {rank}: SIGTERM received, flagging for distributed checkpoint.") + self.checkpoint_requested = True + if self.sigterm_tensor is not None: + self.sigterm_tensor.fill_(1.0) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + rank = os.environ.get("RANK", "-1") + os.makedirs(self.output_dir, exist_ok=True) + self._init_distributed_signal_tensor() + + if torch.cuda.is_available(): + self.checkpoint_stream = torch.cuda.Stream() + self._log_message(f"Rank {rank}: Created dedicated CUDA stream for checkpointing.") + + signal.signal(signal.SIGTERM, self._sigterm_handler) + self._log_message(f"Rank {rank}: Distributed SIGTERM handler registered.") + + try: + if dist.is_initialized(): + dist.barrier() + self._log_message(f"Rank {rank}: Distributed coordination setup synchronized across all ranks") + except Exception as e: + self._log_message(f"Rank {rank}: Failed to synchronize distributed setup: {e}") + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if state.global_step % args.logging_steps == 0: + self._write_progress(state) + + if self.progression_tracker and state.global_step % max(1, args.logging_steps // 2) == 0: + rank = int(os.environ.get('RANK', '0')) + if rank == 0: + # Extract current metrics + latest_loss = 0.0 + latest_lr = 0.0 + if state.log_history: + latest_log = state.log_history[-1] + latest_loss = latest_log.get('loss', latest_log.get('train_loss', latest_log.get('training_loss', 0.0))) + latest_lr = latest_log.get('learning_rate', latest_log.get('lr', latest_log.get('train_lr', 0.0))) + + epoch = max(1, int(state.epoch)) if state.epoch is not None else 1 + step_in_epoch = (state.global_step - 1) % self.progression_tracker.steps_per_epoch if self.progression_tracker.steps_per_epoch > 0 else 0 + + current_time = time.time() + elapsed_time = current_time - self.progression_tracker.start_time + + batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps + if int(os.environ.get('WORLD_SIZE', '1')) > 1: + batch_size *= int(os.environ.get('WORLD_SIZE', '1')) + + total_samples_processed = state.global_step * batch_size + samples_per_second = total_samples_processed / elapsed_time if elapsed_time > 0 else 0 + + self.progression_tracker.update_step( + epoch=epoch, + step=step_in_epoch, + loss=latest_loss, + learning_rate=latest_lr, + checkpoint_dir=self.output_dir, + global_step=state.global_step, + max_steps=state.max_steps, + train_samples_per_second=f"{samples_per_second:.2f}", + train_runtime=f"{elapsed_time:.1f}", + world_size=os.environ.get('WORLD_SIZE', '1'), + local_rank=os.environ.get('LOCAL_RANK', '0') + ) + + if self._check_distributed_sigterm() and not self.save_triggered: + rank = os.environ.get("RANK", "-1") + self._log_message(f"Rank {rank}: Distributed SIGTERM detected, initiating checkpoint at step {state.global_step}.") + self.save_triggered = True + control.should_save = True + control.should_training_stop = True + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self._write_progress(state) + + if self.progression_tracker: + rank = int(os.environ.get('RANK', '0')) + if rank == 0: + self.progression_tracker.current_step = self.progression_tracker.total_steps + 
self.progression_tracker.write_status("Training completed") + + rank = os.environ.get("RANK", "-1") + if rank == "0" and self.checkpoint_requested: + self._log_message(f"Rank {rank}: Training ended due to distributed SIGTERM checkpoint request.") + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.progression_tracker: + rank = int(os.environ.get('RANK', '0')) + if rank == 0: + epoch = max(1, int(state.epoch)) if state.epoch is not None else 1 + latest_loss = 0.0 + if state.log_history: + latest_log = state.log_history[-1] + latest_loss = latest_log.get('loss', latest_log.get('train_loss', latest_log.get('training_loss', 0.0))) + + self.progression_tracker.update_epoch( + epoch=epoch, + checkpoint_dir=self.output_dir, + avg_loss=latest_loss, + global_step=state.global_step, + max_steps=state.max_steps + ) + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + rank = os.environ.get("RANK", "-1") + if rank == "0": + if self.progression_tracker: + try: + trainer_state_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}", 'trainer_state.json') + if os.path.exists(trainer_state_path): + with open(trainer_state_path, 'r') as f: + trainer_state_data = json.load(f) + + trainer_state_data['training_start_time'] = self.progression_tracker.start_time + + with open(trainer_state_path, 'w') as f: + json.dump(trainer_state_data, f, indent=2) + + self._log_message(f"Rank {rank}: Saved training start time to checkpoint.") + except Exception as e: + self._log_message(f"Rank {rank}: Failed to save training start time: {e}") + + self._log_message(f"Rank {rank}: Checkpoint save completed.") + if self.checkpoint_requested: + self._log_message(f"Rank {rank}: Distributed SIGTERM-triggered checkpoint save finished successfully.") + + def setup_distributed(): + """Initialize distributed training.""" + node_rank = int(os.getenv('PET_NODE_RANK', '0')) + num_nodes = int(os.getenv('PET_NNODES', '1')) + nproc_per_node = int(os.getenv('PET_NPROC_PER_NODE', '1')) + master_addr = os.getenv('PET_MASTER_ADDR', 'localhost') + master_port = os.getenv('PET_MASTER_PORT', '29500') + + local_rank = int(os.getenv('LOCAL_RANK', '0')) + world_size = num_nodes * nproc_per_node + global_rank = node_rank * nproc_per_node + local_rank + + os.environ['RANK'] = str(global_rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + + if world_size > 1: + try: + torch.distributed.init_process_group( + backend='gloo', + rank=global_rank, + world_size=world_size + ) + torch.distributed.barrier() + except Exception as e: + print(f"Warning: Failed to initialize distributed training: {e}") + + return local_rank, global_rank, world_size + + def load_dataset_from_initializer(): + """Load dataset from initializer or download.""" + dataset_dir = Path("/workspace/dataset") + + if dataset_dir.exists() and any(dataset_dir.iterdir()): + try: + full_dataset = load_from_disk(str(dataset_dir)) + if isinstance(full_dataset, dict): + train_dataset = full_dataset.get('train', full_dataset.get('train[:100]')) + test_dataset = full_dataset.get('test', full_dataset.get('test[:20]')) + else: + train_size = min(100, len(full_dataset) - 20) + train_dataset = full_dataset.select(range(train_size)) + test_dataset = full_dataset.select(range(train_size, min(train_size + 20, len(full_dataset)))) + + return 
train_dataset, test_dataset + except Exception as e: + print(f"Failed to load from initializer: {e}") + + dataset_name = os.getenv('DATASET_NAME', 'tatsu-lab/alpaca') + train_split = os.getenv('DATASET_TRAIN_SPLIT', 'train[:100]') + test_split = os.getenv('DATASET_TEST_SPLIT', 'train[100:120]') + + train_dataset = load_dataset(dataset_name, split=train_split) + test_dataset = load_dataset(dataset_name, split=test_split) + + return train_dataset, test_dataset + + def load_model_from_initializer(): + """Load model and tokenizer.""" + model_dir = Path("/workspace/model") + + if model_dir.exists() and any(model_dir.iterdir()): + model_path = str(model_dir) + else: + model_path = os.getenv('MODEL_NAME', 'gpt2') + + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if tokenizer.chat_template is None: + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "### Instruction:\n{{ message['content'] }}\n" + "{% elif message['role'] == 'assistant' %}" + "### Response:\n{{ message['content'] }}{{ eos_token }}\n" + "{% endif %}" + "{% endfor %}" + ) + + return model_path, tokenizer + + except Exception as e: + print(f"Error loading model: {e}") + model_path = 'gpt2' + tokenizer = AutoTokenizer.from_pretrained(model_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return model_path, tokenizer + + def prepare_datasets(train_dataset, test_dataset, tokenizer): + """Prepare datasets for training.""" + def template_dataset(sample): + if 'instruction' in sample and 'output' in sample: + messages = [ + {"role": "user", "content": sample['instruction']}, + {"role": "assistant", "content": sample['output']}, + ] + elif 'question' in sample and 'answer' in sample: + messages = [ + {"role": "user", "content": sample['question']}, + {"role": "assistant", "content": sample['answer']}, + ] + else: + content = str(sample.get('text', sample.get('content', 'Sample text'))) + messages = [ + {"role": "user", "content": "Complete this text:"}, + {"role": "assistant", "content": content}, + ] + + return {"text": tokenizer.apply_chat_template(messages, tokenize=False)} + + train_columns = list(train_dataset.features.keys()) + train_columns.remove('text') if 'text' in train_columns else None + + train_dataset = train_dataset.map(template_dataset, remove_columns=train_columns) + + if test_dataset is not None: + test_columns = list(test_dataset.features.keys()) + test_columns.remove('text') if 'text' in test_columns else None + test_dataset = test_dataset.map(template_dataset, remove_columns=test_columns) + + return train_dataset, test_dataset + + def get_training_parameters(): + """Get training parameters.""" + checkpoint_dir = Path(os.getenv('CHECKPOINT_URI', '/workspace/checkpoints')) + checkpoint_enabled = os.getenv('CHECKPOINT_ENABLED', 'false').lower() == 'true' + checkpoint_interval = os.getenv('CHECKPOINT_INTERVAL', '30s') + max_checkpoints = int(os.getenv('CHECKPOINT_MAX_RETAIN', '5')) + + parameters = { + 'model_name_or_path': os.getenv('MODEL_NAME', 'gpt2'), + 'model_revision': 'main', + 'torch_dtype': 'bfloat16', + 'use_peft': True, + 'lora_r': int(os.getenv('LORA_R', '16')), + 'lora_alpha': int(os.getenv('LORA_ALPHA', '32')), + 'lora_dropout': float(os.getenv('LORA_DROPOUT', '0.1')), + 'lora_target_modules': ['c_attn', 'c_proj'], # GPT-2 specific + 'dataset_name': os.getenv('DATASET_NAME', 
'tatsu-lab/alpaca'), + 'dataset_config': 'main', + 'dataset_train_split': os.getenv('DATASET_TRAIN_SPLIT', 'train[:100]'), + 'dataset_test_split': os.getenv('DATASET_TEST_SPLIT', 'train[100:120]'), + 'num_train_epochs': int(os.getenv('MAX_EPOCHS', '3')), + 'per_device_train_batch_size': int(os.getenv('BATCH_SIZE', '2')), + 'per_device_eval_batch_size': int(os.getenv('BATCH_SIZE', '2')), + 'eval_strategy': 'steps', + 'eval_steps': int(os.getenv('EVAL_STEPS', '25')), + 'bf16': torch.cuda.is_available(), # Only use bf16 if CUDA is available + 'fp16': not torch.cuda.is_available(), # Use fp16 for CPU training + 'learning_rate': float(os.getenv('LEARNING_RATE', '5e-5')), + 'warmup_steps': int(os.getenv('WARMUP_STEPS', '10')), + 'lr_scheduler_type': 'cosine', + 'optim': 'adamw_torch', + 'max_grad_norm': 1.0, + 'seed': 42, + 'gradient_accumulation_steps': int(os.getenv('GRADIENT_ACCUMULATION_STEPS', '4')), + 'save_strategy': 'steps', + 'save_steps': int(os.getenv('SAVE_STEPS', '20')), + 'save_total_limit': max_checkpoints if checkpoint_enabled else None, + 'logging_strategy': 'steps', + 'logging_steps': int(os.getenv('LOGGING_STEPS', '5')), + 'report_to': [], + 'output_dir': str(checkpoint_dir), + # Fix DDP parameter marking issue with LoRA + 'gradient_checkpointing': False, # Disable gradient checkpointing to avoid DDP conflicts + 'ddp_find_unused_parameters': False, # Optimize DDP performance + 'ddp_backend': 'gloo', # Use gloo backend for better LoRA compatibility + 'dataloader_pin_memory': False, # Disable pin memory for CPU training + } + + + return parameters + + + """Training function.""" + + import os + + local_rank, global_rank, world_size = setup_distributed() + + if world_size > 1: + try: + if dist.is_initialized(): + dist.barrier() + except Exception as e: + print(f"Warning: Failed to synchronize distributed setup: {e}") + + os.makedirs("/workspace/cache/transformers", exist_ok=True) + os.makedirs("/workspace/cache", exist_ok=True) + os.makedirs("/workspace/cache/datasets", exist_ok=True) + + parameters = get_training_parameters() + checkpoint_dir = Path(parameters['output_dir']) + os.makedirs(checkpoint_dir, exist_ok=True) + + + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_dict(parameters) + + set_seed(training_args.seed) + + model_path, tokenizer = load_model_from_initializer() + train_dataset, test_dataset = load_dataset_from_initializer() + train_dataset, test_dataset = prepare_datasets(train_dataset, test_dataset, tokenizer) + + progression_tracker = None + + callbacks = [ + DistributedCheckpointCallback(str(checkpoint_dir), progression_tracker) + ] + + # Fix DDP parameter marking issue with LoRA + if world_size > 1: + # Set static graph for DDP to avoid parameter marking issues + os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL' + + trainer = SFTTrainer( + model=model_path, + args=training_args, + train_dataset=train_dataset, + eval_dataset=test_dataset, + peft_config=get_peft_config(model_args), + processing_class=tokenizer, + callbacks=callbacks, + ) + + # Apply DDP fix after trainer initialization + if world_size > 1: + try: + # Check if model is wrapped in DDP + if hasattr(trainer.model, 'module'): + # Set static graph to prevent parameter marking issues + trainer.model._set_static_graph() + print(f"Applied DDP static graph fix for distributed training") + elif hasattr(trainer.accelerator.unwrap_model(trainer.model), 'base_model'): + # For PEFT models, ensure proper DDP handling + unwrapped_model = 
trainer.accelerator.unwrap_model(trainer.model) + if hasattr(unwrapped_model, '_ddp_params_and_buffers_to_ignore'): + print(f"PEFT model detected, DDP parameters properly configured") + except Exception as e: + print(f"Warning: Could not apply DDP static graph fix: {e}") + + if trainer.accelerator.is_main_process and hasattr(trainer.model, "print_trainable_parameters"): + trainer.model.print_trainable_parameters() + + checkpoint = get_last_checkpoint(training_args.output_dir) + resume_from_epoch = 0 + resume_from_step = 0 + + if checkpoint is not None: + try: + checkpoint_files = os.listdir(checkpoint) + if 'trainer_state.json' not in checkpoint_files: + checkpoint = None + else: + trainer_state_path = os.path.join(checkpoint, 'trainer_state.json') + if os.path.exists(trainer_state_path): + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + resume_from_epoch = int(trainer_state.get('epoch', 0)) + resume_from_step = int(trainer_state.get('global_step', 0)) + print(f"Resuming from checkpoint: epoch {resume_from_epoch}, step {resume_from_step}") + except Exception as e: + print(f"Checkpoint validation failed: {e}") + checkpoint = None + + if world_size == 1 or global_rank == 0: + train_dataset_size = len(train_dataset) if train_dataset else 1000 # fallback + batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps + if world_size > 1: + batch_size *= world_size + steps_per_epoch = max(1, train_dataset_size // batch_size) + + num_epochs = int(training_args.num_train_epochs) + # Fix: Total epochs should always be the configured number, not additive + total_epochs_planned = num_epochs + + progression_tracker = ProgressionTracker( + total_epochs=total_epochs_planned, + steps_per_epoch=steps_per_epoch, + update_interval=int(os.getenv('PROGRESSION_UPDATE_INTERVAL', '10')) + ) + + if checkpoint is not None and (resume_from_epoch > 0 or resume_from_step > 0): + # Set current progress based on what was actually completed + progression_tracker.current_epoch = resume_from_epoch + # Use global_step if available, otherwise calculate from epochs + if resume_from_step > 0: + progression_tracker.current_step = resume_from_step + else: + # Calculate steps from completed epochs + progression_tracker.current_step = resume_from_epoch * steps_per_epoch + + try: + trainer_state_path = os.path.join(checkpoint, 'trainer_state.json') + if os.path.exists(trainer_state_path): + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + if 'training_start_time' in trainer_state: + progression_tracker.start_time = trainer_state['training_start_time'] + print(f"Restored original training start time from checkpoint") + else: + current_time = time.time() + if progression_tracker.current_step > 0 and progression_tracker.total_steps > 0: + # Better estimation: assume average 2 seconds per step (more realistic) + estimated_elapsed = progression_tracker.current_step * 2.0 + progression_tracker.start_time = current_time - estimated_elapsed + print(f"Estimated original training start time based on completed steps: {progression_tracker.current_step} steps") + except Exception as e: + print(f"Could not restore training start time: {e}") + + progression_tracker.write_status(f"Training resumed from epoch {resume_from_epoch}") + else: + progression_tracker.current_epoch = 0 + progression_tracker.current_step = 0 + progression_tracker.write_status("Training started") + + if callbacks and len(callbacks) > 0: + callbacks[0].progression_tracker = 
progression_tracker + + if world_size > 1: + try: + if dist.is_initialized(): + dist.barrier() + except Exception as e: + print(f"Warning: Failed to synchronize distributed processes: {e}") + + try: + trainer.train(resume_from_checkpoint=checkpoint) + except Exception as e: + print(f"Training failed: {e}") + if checkpoint is not None: + try: + if progression_tracker: + progression_tracker.current_epoch = 0 + progression_tracker.current_step = 0 + progression_tracker.write_status("Training restarted from scratch after checkpoint failure") + trainer.train(resume_from_checkpoint=None) + except Exception as retry_e: + print(f"Training failed even from scratch: {retry_e}") + raise retry_e + else: + raise + + trainer.save_model(training_args.output_dir) + + if progression_tracker and (world_size == 1 or global_rank == 0): + progression_tracker.current_step = progression_tracker.total_steps + progression_tracker.write_status("Training completed successfully") + + print("Waiting for progression status to be captured...") + time.sleep(30) + + if world_size > 1: + try: + if dist.is_initialized(): + dist.destroy_process_group() + except Exception as e: + print(f"Warning: Failed to cleanup distributed process group: {e}") + + \ No newline at end of file diff --git a/examples/kft-v2/trl-gpt2-checkpointing.ipynb b/examples/kft-v2/trl-gpt2-checkpointing.ipynb new file mode 100644 index 000000000..47e0f7b57 --- /dev/null +++ b/examples/kft-v2/trl-gpt2-checkpointing.ipynb @@ -0,0 +1,450 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ec6b8bb7", + "metadata": {}, + "source": [ + "## TRL(Transformer Reinforcement Learning) Training with Kubeflow SDK and Advanced Checkpointing\n", + "\n", + "This notebook demonstrates how to use the Kubeflow Trainer SDK to create and manage TrainJobs\n", + "\n", + "### Features Demonstrated\n", + "- **Kubeflow SDK Integration**: Programmatic TrainJob creation and management\n", + "- **Checkpointing**: Controller-managed resume/suspended compatibility for model checkpoints\n", + "- **TRL SFTTrainer**: Supervised fine-tuning using Peft-LoRA with GPT-2 and Alpaca dataset for instruction following\n", + "- **Distributed Training**: Multi-node Multi-GPU coordination\n", + "- **Compute resource pre-requisite for this demo** : \n", + " This demo can run on -\n", + " - CPUs based training using GLOO backend (default configuration)\n", + " - GPUs(NVIDIA/AMD) based training using NCCL backend\n", + " - Respective training images (update in [torch-cuda-custom](./cluster_training_runtime.yaml)):\n", + " - quay.io/modh/training:py311-cuda124-torch251\n", + " - quay.io/modh/training:py311-rocm62-torch251\n", + " - Multi-node Multi-GPU distributed training using Trainer V2 MlPolicies (NumNodes/NProcPerNodes)\n", + "\n", + "### Prerequisites\n", + "- Persistent volume storage with RWX(ReadWriteManyAccess) : [workspace](workspace-checkpoint-storage)\n", + "- ClusterTrainingRuntime : [torch-cuda-custom](./cluster_training_runtime.yaml)\n", + "\n", + "### Sample scripts\n", + "- [mnist.py](./scripts/mnist.py)\n", + "- [trl_training.py](./scripts/trl_training.py)\n", + "- _oc apply -k examples/kft-v2/manifests_\n", + "\n", + "### References\n", + "- [Kubeflow Trainer SDK](https://github.com/kubeflow/sdk)\n", + "- [TRL Documentation](https://huggingface.co/docs/trl/)\n", + "- [PEFT Documentation](https://huggingface.co/docs/peft/)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91c67d2f", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + 
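"# Note: '%pip install kubeflow' below installs the released SDK from PyPI.\n", + "# To follow the main branch of https://github.com/kubeflow/sdk instead, an\n", + "# install straight from GitHub may work (hypothetical command, adjust to the\n", + "# repo's packaging layout): %pip install git+https://github.com/kubeflow/sdk.git@main\n", + "# Restart the kernel after installing so the new package is picked up.\n", + 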
"# Install Kubeflow SDK from source github main branch\n", + "%pip install kubeflow" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fa66b8bd-47f4-4f72-9332-50db5d827d37", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Name: kubeflow\n", + "Version: 0.1.0\n", + "Summary: Kubeflow Python SDK to manage ML workloads and to interact with Kubeflow APIs.\n", + "Home-page: https://github.com/kubeflow/sdk\n", + "Author: \n", + "Author-email: The Kubeflow Authors \n", + "License: \n", + "Location: /opt/app-root/lib64/python3.12/site-packages\n", + "Requires: kubeflow-trainer-api, kubernetes, pydantic\n", + "Required-by: \n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip show kubeflow" + ] + }, + { + "cell_type": "markdown", + "id": "4f97773c", + "metadata": {}, + "source": [ + "### Define TRL Training Function\n", + "- Progress file writer (callbacks)\n", + "- Distributed checkpoint coordination\n", + "- Automated model checkpointing by SIGTERM signal handling" + ] + }, + { + "cell_type": "markdown", + "id": "6c002e5e", + "metadata": {}, + "source": [ + "### Create TrainJob Using Kubeflow SDK\n", + "Now we'll use the Kubeflow SDK to create a TrainJob\n", + "- Training arguments\n", + "- *CustomTrainer* with the TRL training function\n", + "- *Initializer* for dataset and model (V2 initializers)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6b94b783", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training configuration initialised!\n" + ] + } + ], + "source": [ + "from kubeflow.trainer import CustomTrainer, Initializer\n", + "\n", + "training_env_args = {\n", + " \"PYTHONUNBUFFERED\": \"1\",\n", + " \"NCCL_DEBUG\": \"INFO\",\n", + " \"TORCH_DISTRIBUTED_DEBUG\": \"INFO\",\n", + " \"PYTHONPATH\": \"/tmp/lib:$PYTHONPATH\",\n", + "\n", + " # Training hyperparameters\n", + " \"LEARNING_RATE\": \"5e-5\",\n", + " \"BATCH_SIZE\": \"1\",\n", + " \"MAX_EPOCHS\": \"3\",\n", + " \"WARMUP_STEPS\": \"5\",\n", + " \"EVAL_STEPS\": \"3\",\n", + " \"SAVE_STEPS\": \"2\",\n", + " \"LOGGING_STEPS\": \"2\",\n", + " \"GRADIENT_ACCUMULATION_STEPS\": \"2\",\n", + " \n", + " # Model configuration\n", + " \"MODEL_NAME\": \"gpt2\",\n", + " \"LORA_R\": \"16\",\n", + " \"LORA_ALPHA\": \"32\",\n", + " \"LORA_DROPOUT\": \"0.1\",\n", + " \"MAX_SEQ_LENGTH\": \"512\",\n", + " \n", + " # Dataset configuration\n", + " \"DATASET_NAME\": \"tatsu-lab/alpaca\",\n", + " \"DATASET_TRAIN_SPLIT\": \"train[:500]\",\n", + " \"DATASET_TEST_SPLIT\": \"train[500:520]\",\n", + " \n", + " # Checkpointing configuration\n", + " \"CHECKPOINT_URI\": \"/workspace/checkpoints\",\n", + " \"TRAINJOB_PROGRESSION_FILE_PATH\": \"/tmp/training_progression.json\",\n", + " \n", + " # Cache directories\n", + " \"PYTHONUNBUFFERED\": \"1\",\n", + " \"TRANSFORMERS_CACHE\": \"/workspace/cache/transformers\",\n", + " \"HF_HOME\": \"/workspace/cache\",\n", + " \"HF_DATASETS_CACHE\": \"/workspace/cache/datasets\",\n", + " \n", + " # Distributed training debug\n", + " \"NCCL_DEBUG\": \"INFO\",\n", + " \"NCCL_DEBUG_SUBSYS\": \"ALL\",\n", + " \"NCCL_SOCKET_IFNAME\": \"eth0\",\n", + " \"NCCL_IB_DISABLE\": \"1\",\n", + " \"NCCL_P2P_DISABLE\": \"1\",\n", + " \"NCCL_TREE_THRESHOLD\": \"0\",\n", + " \"TORCH_DISTRIBUTED_DEBUG\": \"INFO\",\n", + " \"TORCH_SHOW_CPP_STACKTRACES\": \"1\",\n", + "}\n", + "\n", + "from trl_training import trl_train\n", + "\n", + "# Create 
CustomTrainer configuration\n", + "custom_trainer = CustomTrainer(\n", + " func=trl_train,\n", + " num_nodes=2, # Distributed training across 2 nodes\n", + " resources_per_node={\n", + " \"cpu\": \"2\",\n", + " \"memory\": \"4Gi\",\n", + " # Uncomment for GPU training:\n", + " # \"nvidia.com/gpu\": \"1\",\n", + " },\n", + " packages_to_install=[\n", + " \"transformers[torch]\",\n", + " \"trl\", \n", + " \"peft\", \n", + " \"datasets\", \n", + " \"accelerate\",\n", + " \"torch\",\n", + " \"numpy\"\n", + " \" --target=/tmp/lib\"\n", + " \" --verbose\"\n", + " ],\n", + " env=training_env_args\n", + ")\n", + "from kubeflow.trainer.types import types\n", + "\n", + "# Configure Initializers\n", + "initializer = Initializer(\n", + " dataset=types.HuggingFaceDatasetInitializer(\n", + " storage_uri=\"hf://tatsu-lab/alpaca\"\n", + " ),\n", + " model=types.HuggingFaceModelInitializer(\n", + " storage_uri=\"hf://gpt2\"\n", + " )\n", + ")\n", + "\n", + "print(\"Training configuration initialised!\")" + ] + }, + { + "cell_type": "markdown", + "id": "4e632ece", + "metadata": {}, + "source": [ + "### Initialize Trainer Client\n", + "Use token authentication to intialize a training client and list available runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "31077051", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available runtimes : 5\n", + "- torch-cuda-241\n", + "- torch-cuda-251\n", + "- torch-cuda-custom\n", + "- torch-rocm-241\n", + "- torch-rocm-251\n" + ] + } + ], + "source": [ + "from kubeflow.trainer import TrainerClient\n", + "from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig\n", + "from kubernetes import client\n", + "\n", + "api_server = \"\"\n", + "token = \"\"\n", + "\n", + "configuration = client.Configuration()\n", + "configuration.host = api_server\n", + "configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n", + "\n", + "# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA\n", + "configuration.verify_ssl = False\n", + "\n", + "api_client = client.ApiClient(configuration)\n", + "trainer_client = TrainerClient(backend_config= KubernetesBackendConfig(client_configuration=api_client.configuration))\n", + "\n", + "print(\"Available runtimes :\", len(trainer_client.list_runtimes()))\n", + "for r in trainer_client.list_runtimes():\n", + " print(f\"- {r.name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "373b1160", + "metadata": {}, + "source": [ + "### Create TrainJob\n", + "Create a TrainJob using resources declared above - \n", + "- Custom trainer\n", + "- Dataset & Model initailisers " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ab961c44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trainjob submitted!!\n" + ] + } + ], + "source": [ + "job_name = trainer_client.train(\n", + " trainer=custom_trainer,\n", + " initializer=initializer,\n", + " runtime=trainer_client.get_runtime(\"torch-cuda-custom\")\n", + ")\n", + "print(\"Trainjob submitted!!\")" + ] + }, + { + "cell_type": "markdown", + "id": "dd299d02", + "metadata": {}, + "source": [ + "![pods](./docs/trainjobs_jobsets.png)\n" + ] + }, + { + "cell_type": "markdown", + "id": "440c7242", + "metadata": {}, + "source": [ + "![jobs](./docs/jobs.png)" + ] + }, + { + "cell_type": "markdown", + "id": "805900d1", + "metadata": {}, + "source": [ + "### Start monitoring - View Training Logs " + ] + }, + { + 
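"cell_type": "code", + "execution_count": null, + "id": "7c1d4e9a", + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: poll the TrainJob status until it finishes before fetching logs.\n", + "# This sketch only uses get_job()/job.status, which appear elsewhere in this\n", + "# notebook; comparing against the 'Complete'/'Failed' strings and the 30s\n", + "# polling interval are assumptions.\n", + "import time\n", + "\n", + "for _ in range(60):\n", + "    job = trainer_client.get_job(job_name)\n", + "    print(f\"TrainJob '{job.name}' status: {job.status}\")\n", + "    if job.status in (\"Complete\", \"Failed\"):\n", + "        break\n", + "    time.sleep(30)" + ] + }, + { + 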
"cell_type": "markdown", + "id": "5613e350", + "metadata": {}, + "source": [ + "![trainjob_pods](./docs/trainjob_pods.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96e2a0b4", + "metadata": {}, + "outputs": [], + "source": [ + "# Get training logs\n", + "try: \n", + " # Get logs from the training nodes\n", + " logs = trainer_client.get_job_logs(job_name, follow=False)\n", + " \n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"TRAINING LOGS\")\n", + " print(\"=\"*80)\n", + " \n", + " # Display logs - logs is a generator, not a dict\n", + " for log_line in logs:\n", + " if log_line.strip():\n", + " print(log_line)\n", + " \n", + " print(\"\\n\" + \"=\"*80)\n", + " \n", + "except Exception as e:\n", + " print(f\"Error getting logs: {e}\")\n", + " print(\"Note: Logs may not be available yet if training is still starting up\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0dae52b", + "metadata": {}, + "source": [ + "### Cleanup resources" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b5a2b3e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final TrainJob Status:\n", + " Name: d5648a3bd444\n", + " Status: Complete\n", + " Created: 2025-10-06 15:08:34+00:00\n", + " Nodes: 2\n", + " Runtime: torch-cuda-custom\n", + " Steps:\n", + " - dataset-initializer: Succeeded\n", + " - model-initializer: Succeeded\n", + " - node-0: Succeeded\n", + " - node-1: Succeeded\n", + "\n", + "TrainJob 'd5648a3bd444' deleted successfully\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TrainJob 'hbfe180e23f8' deleted successfully\n" + ] + } + ], + "source": [ + "# Clean up the TrainJob when done\n", + "def cleanup_trainjob():\n", + " \"\"\"Clean up the TrainJob using Kubeflow SDK\"\"\"\n", + " try:\n", + " trainer_client.delete_job(job_name)\n", + " print(f\"TrainJob '{job_name}' deleted successfully\")\n", + " except Exception as e:\n", + " print(f\"Error deleting TrainJob: {e}\")\n", + "\n", + "# Get final job status before cleanup\n", + "try:\n", + " final_job = trainer_client.get_job(job_name)\n", + " print(f\"Final TrainJob Status:\")\n", + " print(f\" Name: {final_job.name}\")\n", + " print(f\" Status: {final_job.status}\")\n", + " print(f\" Created: {final_job.creation_timestamp}\")\n", + " print(f\" Nodes: {final_job.num_nodes}\")\n", + " print(f\" Runtime: {final_job.runtime.name}\")\n", + " \n", + " if final_job.steps:\n", + " print(f\" Steps:\")\n", + " for step in final_job.steps:\n", + " print(f\" - {step.name}: {step.status}\")\n", + " print()\n", + " cleanup_trainjob()\n", + " \n", + "except Exception as e:\n", + " print(f\"Error getting final job status: {e}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}