diff --git a/F2LLM/configs/ray_config.yaml b/F2LLM/configs/ray_config.yaml
new file mode 100644
index 0000000..6396028
--- /dev/null
+++ b/F2LLM/configs/ray_config.yaml
@@ -0,0 +1,144 @@
+# Ray Distributed Training Configuration
+# This configuration file defines settings for Ray Train distributed training
+
+# =============================================================================
+# Ray Cluster Configuration
+# =============================================================================
+cluster:
+  # Ray cluster address
+  # - "auto": Initialize local Ray cluster automatically
+  # - "ray://<head_node_ip>:10001": Connect to remote Ray cluster
+  ray_address: "auto"
+
+# =============================================================================
+# Scaling Configuration
+# =============================================================================
+scaling:
+  # Number of training workers (typically equals number of GPUs)
+  num_workers: 8
+
+  # Whether to use GPU for training
+  use_gpu: true
+
+  # Resources allocated per worker
+  resources_per_worker:
+    CPU: 4  # CPU cores per worker
+    GPU: 1  # GPUs per worker (1 for single-GPU workers)
+
+# =============================================================================
+# Training Configuration
+# =============================================================================
+training:
+  # PyTorch distributed backend
+  # - "nccl": For GPU training (recommended)
+  # - "gloo": For CPU training
+  backend: "nccl"
+
+  # DeepSpeed integration
+  use_deepspeed: true
+
+  # DeepSpeed configuration (maps to Accelerate's deepspeed_config)
+  deepspeed_config:
+    # ZeRO optimization configuration
+    zero_optimization:
+      stage: 2  # ZeRO-2 optimization (1, 2, or 3)
+      # Note: ZeRO-3 requires additional memory management
+
+    # Mixed precision training
+    bf16:
+      enabled: true  # Use BFloat16 for mixed precision
+
+    # Gradient configuration
+    gradient_clipping: 1.0
+    gradient_accumulation_steps: 1
+
+    # Batch size per GPU (should match train_batch_size below)
+    train_micro_batch_size_per_gpu: 8
+
+# =============================================================================
+# Fault Tolerance Configuration
+# =============================================================================
+fault_tolerance:
+  # Maximum number of failures before giving up
+  max_failures: 3
+
+  # Number of checkpoints to keep
+  checkpoint_num_to_keep: 3
+
+  # Checkpoint selection criteria
+  checkpoint_score_attribute: "loss"  # Metric to use for checkpoint selection
+  checkpoint_score_order: "min"       # "min" for loss, "max" for accuracy
+
+# =============================================================================
+# Model and Data Configuration
+# =============================================================================
+# Model path (HuggingFace model or local path)
+model_path: "Qwen/Qwen2.5-0.5B"  # Example: change to your model
+
+# Experiment identification
+experiment_id: "ray_f2llm_training"
+
+# Output directories
+output_dir: "./outputs/ray_train"
+tb_dir: "./tensorboard/ray_train"
+cache_dir: "./cache"
+
+# Training data
+train_data_path: "./data/train"
+
+# =============================================================================
+# Training Hyperparameters
+# =============================================================================
+# Batch size per device
+train_batch_size: 8
+
+# Maximum sequence length
+max_seq_length: 2048
+
+# Optimizer settings
+learning_rate: 1.0e-4
+min_lr: 1.0e-6
+weight_decay: 0.01
+
+# Learning rate schedule
+warmup_steps: 100
+
+# Embedding training settings
+num_hard_neg: 7  # Number of hard negatives per sample
+
+# Training duration
+# train_steps: -1 means use train_epochs instead
+train_steps: -1
+train_epochs: 5
+
+# Logging and checkpointing
+log_interval: 20          # Log every N steps
+checkpointing_steps: 100  # Save checkpoint every N steps
+validation_steps: 100     # Run validation every N steps
+
+# 
=============================================================================
+# Notes and Tips
+# =============================================================================
+# 1. For single-node training:
+#    - Keep ray_address: "auto"
+#    - Set num_workers to number of GPUs
+#
+# 2. For multi-node training:
+#    - Start Ray cluster on head node: ray start --head --port=6379
+#    - Start Ray on worker nodes: ray start --address=<head_node_ip>:6379
+#    - Set ray_address: "ray://<head_node_ip>:10001"
+#    - Set num_workers to total GPUs across all nodes
+#
+# 3. For fault-tolerant training (spot instances):
+#    - Increase max_failures (e.g., 5-10)
+#    - Decrease checkpointing_steps (e.g., 50)
+#    - Increase checkpoint_num_to_keep (e.g., 5)
+#
+# 4. DeepSpeed ZeRO stages:
+#    - ZeRO-1: Optimizer state partitioning
+#    - ZeRO-2: + Gradient partitioning (recommended)
+#    - ZeRO-3: + Parameter partitioning (for very large models)
+#
+# 5. For debugging:
+#    - Set num_workers: 1
+#    - Use local_mode: ray.init(local_mode=True) in code
diff --git a/F2LLM/model.py b/F2LLM/model.py
index d33ade7..534a1f7 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -17,8 +17,21 @@ def __init__(self,
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.max_seq_length = max_seq_length
 
-    def set_device(self):
-        self.device = self.lm.device
+    def set_device(self, device=None):
+        """
+        Set device - compatible with both Accelerate and Ray Train
+
+        Args:
+            device: Specific device to use. If None, auto-detect from model. 
+        """
+        if device is not None:
+            self.device = device
+        elif hasattr(self.lm, 'device'):
+            # Accelerate path: model already has device attribute
+            self.device = self.lm.device
+        else:
+            # Ray Train path: get device from model parameters
+            self.device = next(self.lm.parameters()).device
 
     def forward(self, batch):
         bs = batch['bs']
diff --git a/F2LLM/ray_config.py b/F2LLM/ray_config.py
new file mode 100644
index 0000000..2f7dd0b
--- /dev/null
+++ b/F2LLM/ray_config.py
@@ -0,0 +1,229 @@
+"""
+Ray Train Configuration Module
+
+This module defines the configuration dataclass for Ray distributed training,
+extending the base Args class with Ray-specific parameters.
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Any
+from arguments import Args
+import yaml
+import json
+
+
+@dataclass
+class RayTrainConfig(Args):
+    """
+    Ray Train configuration extending base Args
+
+    Adds Ray-specific settings while maintaining compatibility with
+    existing training configuration. 
+ """ + + # Ray cluster settings + ray_address: str = "auto" # "auto" for local, "ray://head:10001" for remote + num_workers: int = 8 # Number of training workers (typically = num GPUs) + use_gpu: bool = True # Whether to use GPU for training + + # Resource allocation per worker + resources_per_worker: Optional[Dict[str, float]] = None # {"CPU": 4, "GPU": 1} + + # Fault tolerance settings + enable_fault_tolerance: bool = True + max_retries: int = 3 # Maximum number of failure retries + + # DeepSpeed settings (preserved from Accelerate) + use_deepspeed: bool = True + zero_stage: int = 2 # ZeRO optimization stage (1, 2, or 3) + + # Checkpoint settings + checkpoint_num_to_keep: int = 3 + checkpoint_score_attribute: str = "loss" + checkpoint_score_order: str = "min" # "min" or "max" + + # Communication backend + backend: str = "nccl" # "nccl" for GPU, "gloo" for CPU + + def __post_init__(self): + """Post-initialization processing""" + super().__post_init__() if hasattr(super(), '__post_init__') else None + + # Set default resources per worker if not specified + if self.resources_per_worker is None: + self.resources_per_worker = { + "CPU": 4, + "GPU": 1 if self.use_gpu else 0 + } + + @classmethod + def from_yaml(cls, yaml_path: str) -> 'RayTrainConfig': + """ + Load configuration from YAML file + + Args: + yaml_path: Path to YAML configuration file + + Returns: + RayTrainConfig instance + """ + with open(yaml_path, 'r') as f: + config = yaml.safe_load(f) + + # Merge configuration sections + merged_config = {} + + # Merge cluster, scaling, training, fault_tolerance sections + for section in ['cluster', 'scaling', 'training', 'fault_tolerance']: + if section in config: + section_data = config[section] + if isinstance(section_data, dict): + merged_config.update(section_data) + + # Handle DeepSpeed configuration + if 'deepspeed_config' in merged_config: + ds_config = merged_config.pop('deepspeed_config') + if isinstance(ds_config, dict): + # Extract ZeRO stage + zero_config 
= ds_config.get('zero_optimization', {}) + if isinstance(zero_config, dict): + merged_config['zero_stage'] = zero_config.get('stage', 2) + + # Add top-level parameters (model_path, experiment_id, etc.) + for key in ['model_path', 'experiment_id', 'output_dir', 'tb_dir', + 'cache_dir', 'train_data_path', 'train_batch_size', + 'max_seq_length', 'learning_rate', 'min_lr', 'weight_decay', + 'warmup_steps', 'num_hard_neg', 'train_steps', 'train_epochs', + 'log_interval', 'checkpointing_steps', 'validation_steps']: + if key in config: + merged_config[key] = config[key] + + return cls(**merged_config) + + @classmethod + def from_json(cls, json_path: str) -> 'RayTrainConfig': + """ + Load configuration from JSON file (for compatibility with existing configs) + + Args: + json_path: Path to JSON configuration file + + Returns: + RayTrainConfig instance + """ + with open(json_path, 'r') as f: + config = json.load(f) + + return cls(**config) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert configuration to dictionary + + Returns: + Dictionary representation of configuration + """ + return self.dict() + + def save_yaml(self, yaml_path: str): + """ + Save configuration to YAML file + + Args: + yaml_path: Path to save YAML configuration + """ + config_dict = self.to_dict() + + with open(yaml_path, 'w') as f: + yaml.dump(config_dict, f, default_flow_style=False) + + def save_json(self, json_path: str): + """ + Save configuration to JSON file + + Args: + json_path: Path to save JSON configuration + """ + config_dict = self.to_dict() + + with open(json_path, 'w') as f: + json.dump(config_dict, f, indent=2) + + +def create_ray_config_from_accelerate( + accelerate_yaml: str, + base_json: str, + output_yaml: str +) -> RayTrainConfig: + """ + Create Ray configuration from existing Accelerate config files + + This helper function converts Accelerate configuration to Ray format, + preserving all training parameters. 
+ + Args: + accelerate_yaml: Path to accelerate_config.yaml + base_json: Path to config.json (base training config) + output_yaml: Path to save Ray configuration + + Returns: + RayTrainConfig instance + """ + # Load base training config + with open(base_json, 'r') as f: + base_config = json.load(f) + + # Load Accelerate config + with open(accelerate_yaml, 'r') as f: + acc_config = yaml.safe_load(f) + + # Map Accelerate settings to Ray settings + ray_config = { + **base_config, # Include all base training parameters + 'num_workers': acc_config.get('num_processes', 8), + 'use_gpu': not acc_config.get('use_cpu', False), + 'backend': 'nccl' if not acc_config.get('use_cpu', False) else 'gloo', + } + + # Map DeepSpeed settings if present + if 'deepspeed_config' in acc_config: + ds_config = acc_config['deepspeed_config'] + ray_config['use_deepspeed'] = True + ray_config['zero_stage'] = ds_config.get('zero_stage', 2) + + # Create RayTrainConfig instance + config = RayTrainConfig(**ray_config) + + # Save to YAML + config.save_yaml(output_yaml) + + return config + + +if __name__ == "__main__": + # Example usage + print("Ray Train Configuration Module") + print("=" * 60) + + # Example 1: Create config from scratch + config = RayTrainConfig( + model_path="/path/to/model", + experiment_id="ray_test", + output_dir="./outputs", + tb_dir="./tensorboard", + cache_dir="./cache", + train_data_path="./data", + num_workers=8, + use_gpu=True, + ) + + print("\nExample 1: Config created from scratch") + print(f" Workers: {config.num_workers}") + print(f" GPU: {config.use_gpu}") + print(f" DeepSpeed: {config.use_deepspeed} (ZeRO-{config.zero_stage})") + + # Example 2: Load from YAML + # config = RayTrainConfig.from_yaml("configs/ray_config.yaml") + # print("\nExample 2: Config loaded from YAML") + + print("\n" + "=" * 60) diff --git a/F2LLM/ray_train.py b/F2LLM/ray_train.py new file mode 100644 index 0000000..ea7fd70 --- /dev/null +++ b/F2LLM/ray_train.py @@ -0,0 +1,423 @@ +""" +Ray Train 
Integration for F2LLM + +This module provides Ray Train distributed training capabilities for the F2LLM +embedding model, maintaining compatibility with the existing Accelerate-based training. +""" + +from ray_config import RayTrainConfig +from utils import ( + DistributedContext, inbatch_loss, hard_loss, validate, + write_tensorboard, save_checkpoint, + CLASSIFICATION_DATASETS, RETRIEVAL_DATASETS, CLUSTERING_DATASETS +) +from transformers import ( + AutoTokenizer, + set_seed, + get_scheduler +) +import os +import json +import random +from datasets import load_dataset +from torch.utils.data import DataLoader +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.optim import AdamW +from torch.nn import CrossEntropyLoss +from model import F2LLM +from tqdm.auto import tqdm +from torch.utils.tensorboard import SummaryWriter + +import ray.train +from ray.train import ScalingConfig, RunConfig, CheckpointConfig, FailureConfig +from ray.train.torch import TorchTrainer, TorchConfig + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def _stack(input_ids, max_len): + """Stack and truncate input IDs""" + data = [ids[:max_len] for ids in input_ids] + lens = [len(x) for x in data] + tensor = torch.tensor(sum(data, [])) + return tensor.split(lens) + + +def train_func(config): + """ + Main training function executed on each Ray worker + + Args: + config: Training configuration dictionary + """ + # Convert config dict to RayTrainConfig object + args = RayTrainConfig(**config) + + # Initialize distributed context (Ray backend) + distributed_ctx = DistributedContext(backend='ray') + + # Set random seed for reproducibility + set_seed(0) + + # Create output directories (main process only) + if distributed_ctx.is_main_process(): + os.makedirs(args.output_dir, exist_ok=True) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # Load datasets + distributed_ctx.print("Loading datasets...") + train_datasets, 
valid_datasets = [], [] + for f in sorted(os.listdir(args.train_data_path)): + if not f.endswith('.parquet'): + continue + dataset_name = f.split('.parquet')[0] + dataset_path = os.path.join(args.train_data_path, f) + dataset = load_dataset("parquet", data_files=dataset_path, cache_dir=args.cache_dir)['train'] + dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) + dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0) + train_datasets.append((dataset_name, dataset['train'])) + valid_datasets.append((dataset_name, dataset['test'])) + + distributed_ctx.print(f"Loaded {len(train_datasets)} datasets") + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + def collate_fn(batch_raw): + """Collate function for data loading""" + num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else args.num_hard_neg + hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg) + + input_ids = _stack( + [s['query_input_ids'] for s in batch_raw] + + [s['passage_input_ids'] for s in batch_raw] + + [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices], + args.max_seq_length + ) + seqlens = torch.tensor([ids.size(0) for ids in input_ids]) + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attention_masks = input_ids.ne(tokenizer.pad_token_id).long() + + return { + 'input_ids': input_ids, + 'seq_lens': seqlens, + 'attention_mask': attention_masks, + 'bs': len(batch_raw), + 'dataset_name': batch_raw[0]['dataset_name'] + } + + # Create data loaders + train_loaders = { + name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn) + for name, ds in train_datasets + } + valid_loaders = { + name: DataLoader(ds, shuffle=False, batch_size=args.train_batch_size, collate_fn=collate_fn) + for name, ds in valid_datasets + } + + # Prepare data loaders for Ray Train + from 
ray.train.torch import prepare_data_loader + train_loaders = { + name: prepare_data_loader(loader) + for name, loader in train_loaders.items() + } + valid_loaders = { + name: prepare_data_loader(loader) + for name, loader in valid_loaders.items() + } + + class MultiLoader: + """Multi-dataset loader with weighted sampling""" + def __init__(self, loader_dict): + self.loader_dict = loader_dict + + def __len__(self): + return sum(len(v) for v in self.loader_dict.values()) + + def reset_epoch(self, epoch): + self.rng = random.Random(epoch) + self.iters = {k: iter(v) for k, v in self.loader_dict.items()} + self.names = list(self.iters.keys()) + self.weights = [len(self.loader_dict[k]) for k in self.names] + + def __iter__(self): + while self.names: + name = self.rng.choices(self.names, weights=self.weights)[0] + try: + batch = next(self.iters[name]) + yield batch + except StopIteration: + idx = self.names.index(name) + self.names.pop(idx) + self.weights.pop(idx) + + # Determine training steps + override_train_step = False + if args.train_steps < 0: + args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs + override_train_step = True + + distributed_ctx.print(f"Training steps before prepare: {args.train_steps}") + + # Create model + distributed_ctx.print("Creating model...") + model = F2LLM(args.model_path, args.max_seq_length, args=args) + model.lm.gradient_checkpointing_enable() + + # Set seed again for consistent initialization + set_seed(0) + + # Create optimizer and scheduler + optimizer = AdamW( + model.lm.parameters(), + weight_decay=args.weight_decay, + lr=args.learning_rate, + betas=(0.9, 0.98) + ) + + lr_scheduler = get_scheduler( + "cosine", + optimizer=optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.train_steps + ) + + # Prepare model and optimizer for Ray Train + from ray.train.torch import prepare_model, prepare_optimizer + model.lm = prepare_model(model.lm) + optimizer = prepare_optimizer(optimizer) + + 
# Set model device + model.set_device() + + # Create MultiLoader + train_dataloader = MultiLoader(train_loaders) + + # Adjust training steps if needed + if override_train_step: + args.train_steps = len(train_dataloader) * args.train_epochs + distributed_ctx.print(f"Training steps after prepare: {args.train_steps}") + + # Start training + distributed_ctx.print("=" * 80) + distributed_ctx.print("Starting training") + distributed_ctx.print(f" Num train samples = {sum(len(ds) for _, ds in train_datasets)}") + distributed_ctx.print(f" Num epochs = {args.train_epochs}") + distributed_ctx.print(f" Per device batch size = {args.train_batch_size}") + distributed_ctx.print(f" Global batch size = {args.train_batch_size * distributed_ctx.world_size}") + distributed_ctx.print(f" Steps per epoch = {len(train_dataloader)}") + distributed_ctx.print(f" Total training steps = {args.train_steps}") + distributed_ctx.print("=" * 80) + + # Filter datasets + global RETRIEVAL_DATASETS, CLASSIFICATION_DATASETS, CLUSTERING_DATASETS + RETRIEVAL_DATASETS = [ds for ds in RETRIEVAL_DATASETS if ds in train_dataloader.loader_dict.keys()] + CLASSIFICATION_DATASETS = [ds for ds in CLASSIFICATION_DATASETS if ds in train_dataloader.loader_dict.keys()] + CLUSTERING_DATASETS = [ds for ds in CLUSTERING_DATASETS if ds in train_dataloader.loader_dict.keys()] + + # Initialize TensorBoard writer + summary_writer = SummaryWriter(log_dir=args.tb_dir) if distributed_ctx.is_main_process() else None + + # Training loop + criterion = CrossEntropyLoss(reduction='none') + pbar = tqdm(range(args.train_steps), disable=not distributed_ctx.is_local_main_process()) + completed_steps = 0 + + # Initialize loss tracking + loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in 
RETRIEVAL_DATASETS} + count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + + model.lm.train() + + for epoch in range(args.train_epochs): + distributed_ctx.print(f"Starting epoch {epoch+1}") + train_dataloader.reset_epoch(epoch) + + for batch in train_dataloader: + # Forward pass + outputs = model.forward(batch) + + # Compute losses + loss_hard = hard_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion, + distributed_ctx + ) + + dataset_name = batch['dataset_name'] + count_hard_dict[dataset_name] += 1 + loss_hard_dict[dataset_name] += loss_hard.detach().float() + + if dataset_name in RETRIEVAL_DATASETS: + loss = inbatch_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion, + distributed_ctx + ) + count_dict[dataset_name] += 1 + loss_dict[dataset_name] += loss.detach().float() + else: + loss = 0.0 + + loss_total = loss + loss_hard + + # Backward pass + loss_total.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Enforce minimum learning rate + if optimizer.param_groups[0]['lr'] < args.min_lr: + for i in range(len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = args.min_lr + + # Logging + completed_steps += 1 + if completed_steps % args.log_interval == 0: + pbar.update(args.log_interval) + + train_log_dict = {"lr": optimizer.param_groups[0]['lr']} + + # Aggregate losses across GPUs + for k in loss_dict.keys(): + count = distributed_ctx.gather(count_dict[k]).sum() + if count > 0: + train_log_dict[f"{k}/training_loss_in_batch"] = distributed_ctx.gather(loss_dict[k]).sum() / count + + for k in loss_hard_dict.keys(): + count = distributed_ctx.gather(count_hard_dict[k]).sum() + if count > 0: + train_log_dict[f"{k}/training_loss_hard"] = distributed_ctx.gather(loss_hard_dict[k]).sum() / count + + 
# Compute averages + train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() if any(k.endswith('training_loss_in_batch') for k in train_log_dict.keys()) else torch.tensor(0.0) + train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() if any(k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard') for k in train_log_dict.keys()) else torch.tensor(0.0) + train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() if any(k.split('/')[0] in CLASSIFICATION_DATASETS for k in train_log_dict.keys()) else torch.tensor(0.0) + train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() if any(k.split('/')[0] in CLUSTERING_DATASETS for k in train_log_dict.keys()) else torch.tensor(0.0) + + distributed_ctx.print(f"[Train] Step = {completed_steps}") + if distributed_ctx.is_main_process(): + write_tensorboard(summary_writer, train_log_dict, completed_steps) + + # Reset counters + loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} + count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + + # Validation + if completed_steps % args.validation_steps == 0: + model.lm.eval() + validate(args, distributed_ctx, model, valid_loaders, criterion, completed_steps, 
summary_writer) + model.lm.train() + + # Checkpoint saving + if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: + output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") + save_checkpoint(args, distributed_ctx, model, output_dir, lr_scheduler) + + # Report checkpoint to Ray Train + if distributed_ctx.is_main_process(): + ray.train.report( + metrics={"step": completed_steps}, + checkpoint=ray.train.Checkpoint.from_directory(output_dir) + ) + + if completed_steps >= args.train_steps: + break + + # Epoch checkpoint + output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}") + save_checkpoint(args, distributed_ctx, model, output_dir, lr_scheduler) + + if completed_steps % args.validation_steps != 0: + model.lm.eval() + validate(args, distributed_ctx, model, valid_loaders, criterion, completed_steps, summary_writer) + model.lm.train() + + if summary_writer: + summary_writer.close() + + distributed_ctx.print("Training completed!") + + +class RayF2LLMTrainer: + """Ray Train trainer wrapper for F2LLM""" + + def __init__(self, config: RayTrainConfig): + """ + Initialize Ray trainer + + Args: + config: RayTrainConfig instance + """ + self.config = config + + # Create ScalingConfig + scaling_config = ScalingConfig( + num_workers=config.num_workers, + use_gpu=config.use_gpu, + resources_per_worker=config.resources_per_worker, + ) + + # Create TorchConfig + torch_config = TorchConfig( + backend=config.backend, + ) + + # Create CheckpointConfig + checkpoint_config = CheckpointConfig( + num_to_keep=config.checkpoint_num_to_keep, + checkpoint_score_attribute=config.checkpoint_score_attribute, + checkpoint_score_order=config.checkpoint_score_order, + ) + + # Create FailureConfig if fault tolerance enabled + failure_config = None + if config.enable_fault_tolerance: + failure_config = FailureConfig(max_failures=config.max_retries) + + # Create RunConfig + run_config = RunConfig( + name=config.experiment_id, + 
storage_path=os.path.dirname(config.output_dir), + checkpoint_config=checkpoint_config, + failure_config=failure_config, + ) + + # Create TorchTrainer + self.trainer = TorchTrainer( + train_loop_per_worker=train_func, + train_loop_config=config.dict(), + scaling_config=scaling_config, + torch_config=torch_config, + run_config=run_config, + ) + + def fit(self): + """Start training""" + result = self.trainer.fit() + return result + + +if __name__ == "__main__": + # Example usage for testing + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to ray_config.yaml") + args = parser.parse_args() + + config = RayTrainConfig.from_yaml(args.config) + trainer = RayF2LLMTrainer(config) + result = trainer.fit() + print(f"Training completed: {result}") diff --git a/F2LLM/requirements.txt b/F2LLM/requirements.txt index 82fb447..5cccbd9 100644 --- a/F2LLM/requirements.txt +++ b/F2LLM/requirements.txt @@ -3,5 +3,9 @@ datasets deepspeed flash-attn torch -transformers +transformers>=4.51.0 tensorboard + +# Ray distributed training (optional, for Ray Train support) +ray[train]>=2.30.0 +pyyaml>=6.0 diff --git a/F2LLM/tests/__init__.py b/F2LLM/tests/__init__.py new file mode 100644 index 0000000..8a9513e --- /dev/null +++ b/F2LLM/tests/__init__.py @@ -0,0 +1,6 @@ +""" +F2LLM Test Suite + +This package contains unit tests and integration tests for the F2LLM +distributed training system. +""" diff --git a/F2LLM/tests/test_distributed_context.py b/F2LLM/tests/test_distributed_context.py new file mode 100644 index 0000000..604c363 --- /dev/null +++ b/F2LLM/tests/test_distributed_context.py @@ -0,0 +1,276 @@ +""" +Unit tests for DistributedContext + +Tests the abstraction layer for both Accelerate and Ray Train backends. 
+""" + +import unittest +import sys +import os +from unittest.mock import Mock, patch, MagicMock + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +import torch +from utils import DistributedContext + + +class TestDistributedContextAccelerate(unittest.TestCase): + """Test DistributedContext with Accelerate backend""" + + @patch('utils.Accelerator') + def test_init_accelerate(self, MockAccelerator): + """Test Accelerate backend initialization""" + mock_acc = MockAccelerator.return_value + mock_acc.process_index = 0 + mock_acc.num_processes = 4 + mock_acc.local_process_index = 0 + + ctx = DistributedContext(backend='accelerate') + + self.assertEqual(ctx.backend, 'accelerate') + self.assertEqual(ctx.rank, 0) + self.assertEqual(ctx.world_size, 4) + self.assertEqual(ctx.local_rank, 0) + MockAccelerator.assert_called_once() + + @patch('utils.Accelerator') + def test_is_main_process(self, MockAccelerator): + """Test main process detection""" + mock_acc = MockAccelerator.return_value + mock_acc.process_index = 0 + mock_acc.num_processes = 4 + mock_acc.local_process_index = 0 + + ctx = DistributedContext(backend='accelerate') + self.assertTrue(ctx.is_main_process()) + + # Non-main process + mock_acc.process_index = 1 + ctx2 = DistributedContext(backend='accelerate') + self.assertFalse(ctx2.is_main_process()) + + @patch('utils.Accelerator') + def test_gather_accelerate(self, MockAccelerator): + """Test gather operation with Accelerate""" + mock_acc = MockAccelerator.return_value + mock_acc.process_index = 0 + mock_acc.num_processes = 1 + + tensor = torch.randn(4, 128) + mock_acc.gather.return_value = tensor + + ctx = DistributedContext(backend='accelerate') + result = ctx.gather(tensor) + + mock_acc.gather.assert_called_once_with(tensor) + torch.testing.assert_close(result, tensor) + + @patch('utils.Accelerator') + def test_wait_for_everyone(self, MockAccelerator): + """Test synchronization barrier""" + mock_acc = 
MockAccelerator.return_value + mock_acc.process_index = 0 + mock_acc.num_processes = 4 + + ctx = DistributedContext(backend='accelerate') + ctx.wait_for_everyone() + + mock_acc.wait_for_everyone.assert_called_once() + + +class TestDistributedContextRay(unittest.TestCase): + """Test DistributedContext with Ray backend""" + + @patch('ray.train.get_context') + def test_init_ray(self, mock_get_context): + """Test Ray backend initialization""" + mock_ctx = Mock() + mock_ctx.get_world_rank.return_value = 1 + mock_ctx.get_world_size.return_value = 8 + mock_ctx.get_local_rank.return_value = 1 + mock_get_context.return_value = mock_ctx + + ctx = DistributedContext(backend='ray') + + self.assertEqual(ctx.backend, 'ray') + self.assertEqual(ctx.rank, 1) + self.assertEqual(ctx.world_size, 8) + self.assertEqual(ctx.local_rank, 1) + + @patch('ray.train.get_context') + def test_is_main_process_ray(self, mock_get_context): + """Test main process detection with Ray""" + mock_ctx = Mock() + mock_ctx.get_world_rank.return_value = 0 + mock_ctx.get_world_size.return_value = 8 + mock_ctx.get_local_rank.return_value = 0 + mock_get_context.return_value = mock_ctx + + ctx = DistributedContext(backend='ray') + self.assertTrue(ctx.is_main_process()) + + # Non-main process + mock_ctx.get_world_rank.return_value = 1 + ctx2 = DistributedContext(backend='ray') + self.assertFalse(ctx2.is_main_process()) + + @patch('torch.distributed.is_initialized') + @patch('torch.distributed.all_gather') + @patch('ray.train.get_context') + def test_gather_ray(self, mock_get_context, mock_all_gather, mock_is_init): + """Test gather operation with Ray""" + mock_ctx = Mock() + mock_ctx.get_world_rank.return_value = 0 + mock_ctx.get_world_size.return_value = 2 + mock_ctx.get_local_rank.return_value = 0 + mock_get_context.return_value = mock_ctx + mock_is_init.return_value = True + + tensor = torch.randn(4, 128) + + # Mock all_gather behavior + def all_gather_side_effect(tensor_list, tensor): + for i, t in 
class TestDistributedContextAutoDetect(unittest.TestCase):
    """Backend auto-detection: Ray Train first, Accelerate as fallback."""

    def test_auto_detect_ray(self):
        """When a Ray Train context exists, 'auto' resolves to 'ray'."""
        session = Mock()
        session.get_world_rank.return_value = 0
        session.get_world_size.return_value = 4
        session.get_local_rank.return_value = 0

        with patch('utils.Accelerator') as accelerator_cls, \
                patch('ray.train.get_context', return_value=session):
            context = DistributedContext(backend='auto')

        self.assertEqual(context.backend, 'ray')
        # Accelerate must never be touched when Ray wins detection.
        accelerator_cls.assert_not_called()

    def test_auto_detect_accelerate(self):
        """When Ray is unavailable, 'auto' falls back to Accelerate."""
        with patch('utils.Accelerator') as accelerator_cls, \
                patch('ray.train.get_context', side_effect=ImportError):
            accelerator = accelerator_cls.return_value
            accelerator.process_index = 0
            accelerator.num_processes = 4
            accelerator.local_process_index = 0

            context = DistributedContext(backend='auto')

        self.assertEqual(context.backend, 'accelerate')
        accelerator_cls.assert_called_once()
class TestDistributedContextHelpers(unittest.TestCase):
    """prepare()/unwrap_model() helpers on both backends."""

    @staticmethod
    def _ray_session(rank=0, world=1, local=0):
        """Mock Ray Train context with the given topology."""
        session = Mock()
        session.get_world_rank.return_value = rank
        session.get_world_size.return_value = world
        session.get_local_rank.return_value = local
        return session

    def test_prepare_accelerate(self):
        """prepare() delegates to Accelerator.prepare."""
        with patch('utils.Accelerator') as accelerator_cls:
            accelerator = accelerator_cls.return_value
            accelerator.process_index = 0
            accelerator.num_processes = 1

            model, optimizer = Mock(), Mock()
            accelerator.prepare.return_value = (model, optimizer)

            prepared = DistributedContext(
                backend='accelerate').prepare(model, optimizer)

            accelerator.prepare.assert_called_once_with(model, optimizer)
            self.assertEqual(len(prepared), 2)

    def test_prepare_ray(self):
        """prepare() is a pass-through under Ray (Ray prepares workers itself)."""
        with patch('ray.train.get_context', return_value=self._ray_session()):
            context = DistributedContext(backend='ray')

        model, optimizer = Mock(), Mock()
        self.assertEqual(context.prepare(model, optimizer), (model, optimizer))

    def test_unwrap_model_accelerate(self):
        """unwrap_model() delegates to Accelerator.unwrap_model."""
        with patch('utils.Accelerator') as accelerator_cls:
            accelerator = accelerator_cls.return_value
            accelerator.process_index = 0
            accelerator.num_processes = 1

            wrapped, inner = Mock(), Mock()
            accelerator.unwrap_model.return_value = inner

            result = DistributedContext(
                backend='accelerate').unwrap_model(wrapped)

            accelerator.unwrap_model.assert_called_once_with(wrapped)
            self.assertEqual(result, inner)

    def test_unwrap_model_ray_with_module(self):
        """Under Ray, DDP-style wrappers are unwrapped via `.module`."""
        with patch('ray.train.get_context', return_value=self._ray_session()):
            context = DistributedContext(backend='ray')

        inner = Mock()
        wrapped = Mock()
        wrapped.module = inner

        self.assertEqual(context.unwrap_model(wrapped), inner)


if __name__ == '__main__':
    # Run tests
    unittest.main(verbosity=2)
class DistributedContext:
    """
    Distributed training context abstraction layer.

    Provides a unified API over two backends -- HF Accelerate and Ray Train --
    so training code can switch frameworks without changes.

    Attributes:
        backend: Resolved backend name, 'accelerate' or 'ray'.
        rank: Global rank of this process.
        world_size: Total number of processes.
        local_rank: Rank of this process on its node.
    """

    def __init__(self, backend='auto'):
        """
        Initialize distributed context.

        Args:
            backend: 'auto' (auto-detect), 'accelerate', or 'ray'

        Raises:
            ValueError: If backend is not one of the supported values.
            RuntimeError: If backend='ray' is requested outside an active
                Ray Train session, or backend='auto' finds neither framework.
        """
        self.backend = backend

        if backend == 'auto':
            self.backend = self._detect_backend()

        if self.backend == 'ray':
            self._init_ray()
        elif self.backend == 'accelerate':
            self._init_accelerate()
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def _detect_backend(self):
        """Auto-detect which backend is being used.

        Prefers Ray Train (an active session implies Ray is driving training);
        falls back to Accelerate when Ray is unavailable.
        """
        # Try Ray Train first
        try:
            import ray.train
            ctx = ray.train.get_context()
            if ctx is not None:
                return 'ray'
        except (ImportError, RuntimeError):
            pass

        # Fall back to Accelerate
        try:
            from accelerate import Accelerator
            return 'accelerate'
        except ImportError:
            raise RuntimeError("Neither Ray Train nor Accelerate is available")

    def _init_ray(self):
        """Initialize Ray Train backend.

        Raises:
            RuntimeError: If no Ray Train session is active (get_context()
                returned None). Without this guard the subsequent attribute
                accesses would fail with an opaque AttributeError.
        """
        import ray.train
        self.ctx = ray.train.get_context()
        if self.ctx is None:
            raise RuntimeError(
                "backend='ray' requested but no Ray Train session is active")
        self.rank = self.ctx.get_world_rank()
        self.world_size = self.ctx.get_world_size()
        self.local_rank = self.ctx.get_local_rank()

    def _init_accelerate(self):
        """Initialize Accelerate backend."""
        from accelerate import Accelerator
        self.accelerator = Accelerator()
        self.rank = self.accelerator.process_index
        self.world_size = self.accelerator.num_processes
        self.local_rank = self.accelerator.local_process_index

    def gather(self, tensor):
        """
        Gather tensors from all processes.

        Args:
            tensor: Tensor to gather [local_bs, ...]

        Returns:
            Gathered tensor [world_size * local_bs, ...]
        """
        if self.backend == 'ray':
            return self._ray_gather(tensor)
        else:
            return self.accelerator.gather(tensor)

    def _ray_gather(self, tensor):
        """Ray backend gather implementation using PyTorch distributed.

        Returns the input unchanged when torch.distributed is not initialized
        (e.g. single-process runs).
        """
        if not dist.is_initialized():
            return tensor

        # One buffer per rank, filled in-place by all_gather.
        gathered_tensors = [torch.zeros_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(gathered_tensors, tensor)

        # Concatenate along batch dimension
        return torch.cat(gathered_tensors, dim=0)

    def print(self, *args, **kwargs):
        """Print only on main process."""
        if self.is_main_process():
            print(*args, **kwargs)

    def is_main_process(self):
        """Check if current process is main (global rank 0)."""
        return self.rank == 0

    def is_local_main_process(self):
        """Check if current process is local main (local rank 0)."""
        return self.local_rank == 0

    def wait_for_everyone(self):
        """Synchronization barrier across all processes.

        No-op under Ray when torch.distributed is not initialized.
        """
        if self.backend == 'ray':
            if dist.is_initialized():
                dist.barrier()
        else:
            self.accelerator.wait_for_everyone()

    def prepare(self, *args):
        """
        Prepare models, optimizers, dataloaders.

        For Accelerate: wraps with DDP/FSDP.
        For Ray: pass-through (Ray Train handles preparation separately).
        Mirrors Accelerator.prepare's convention of unwrapping a single arg.
        """
        if self.backend == 'accelerate':
            return self.accelerator.prepare(*args)
        else:
            return args if len(args) > 1 else args[0]

    def unwrap_model(self, model):
        """Unwrap model from DDP/FSDP wrapper."""
        if self.backend == 'accelerate':
            return self.accelerator.unwrap_model(model)
        else:
            # Ray Train uses DDP; DDP exposes the real model as `.module`.
            if hasattr(model, 'module'):
                return model.module
            return model

    def save(self, obj, f):
        """Save object to file (rank-aware under Accelerate)."""
        if self.backend == 'accelerate':
            self.accelerator.save(obj, f)
        else:
            torch.save(obj, f)

    def get_state_dict(self, model):
        """Get model state dictionary (handles sharded state under Accelerate)."""
        if self.backend == 'accelerate':
            return self.accelerator.get_state_dict(model)
        else:
            return self.unwrap_model(model).state_dict()
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional sequence of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, preserving the original
            CLI behavior; passing an explicit list makes the parser testable.

    Returns:
        argparse.Namespace with the parsed options (unset overrides are None).
    """
    parser = argparse.ArgumentParser(
        description="Ray Train launcher for F2LLM distributed training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Local training with 8 GPUs
  python scripts/run_ray_train.py --config F2LLM/configs/ray_config.yaml

  # Multi-node training (connect to existing cluster)
  python scripts/run_ray_train.py --config F2LLM/configs/ray_config.yaml \\
      --ray-address ray://10.0.0.1:10001 --num-workers 16

  # Override specific parameters
  python scripts/run_ray_train.py --config F2LLM/configs/ray_config.yaml \\
      --experiment-id my_experiment --train-epochs 10
        """
    )

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to Ray configuration file (YAML)"
    )

    parser.add_argument(
        "--ray-address",
        type=str,
        default=None,
        help="Ray cluster address (e.g., 'ray://head:10001'). If not specified, uses value from config."
    )

    # Configuration overrides (default to None so main() can tell "unset"
    # apart from an explicit value)
    parser.add_argument("--num-workers", type=int, help="Override number of workers")
    parser.add_argument("--experiment-id", type=str, help="Override experiment ID")
    parser.add_argument("--train-epochs", type=int, help="Override training epochs")
    parser.add_argument("--train-steps", type=int, help="Override training steps")
    parser.add_argument("--learning-rate", type=float, help="Override learning rate")
    parser.add_argument("--model-path", type=str, help="Override model path")

    # Debugging options
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Run Ray in local mode for debugging (single process)"
    )

    parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="Disable GPU usage (CPU training)"
    )

    return parser.parse_args(argv)
def main():
    """Main entry point.

    Loads the YAML config, applies CLI overrides, initializes Ray (local,
    auto, or remote cluster), runs training, and shuts Ray down.

    Returns:
        Process exit status: 0 on success, 1 on failure or user abort.
    """
    args = parse_args()

    print("=" * 80)
    print("Ray Train Launcher for F2LLM")
    print("=" * 80)

    # Load configuration
    print(f"\nLoading configuration from: {args.config}")
    config = RayTrainConfig.from_yaml(args.config)

    # Apply command-line overrides. Compare against None (not truthiness) so
    # explicit values like `--train-steps -1` or a zero override are honored.
    if args.ray_address is not None:
        config.ray_address = args.ray_address
    if args.num_workers is not None:
        config.num_workers = args.num_workers
    if args.experiment_id is not None:
        config.experiment_id = args.experiment_id
    if args.train_epochs is not None:
        config.train_epochs = args.train_epochs
    if args.train_steps is not None:
        config.train_steps = args.train_steps
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.model_path is not None:
        config.model_path = args.model_path
    if args.no_gpu:
        config.use_gpu = False

    # Display configuration
    print("\nTraining Configuration:")
    print(f"  Experiment ID: {config.experiment_id}")
    print(f"  Model: {config.model_path}")
    print(f"  Workers: {config.num_workers}")
    print(f"  GPUs: {'Yes' if config.use_gpu else 'No'}")
    print(f"  Epochs: {config.train_epochs}")
    print(f"  Batch size: {config.train_batch_size}")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Output dir: {config.output_dir}")

    # Initialize Ray
    print("\nInitializing Ray...")
    ray_address = config.ray_address

    if args.local_mode:
        print("  Mode: Local (debugging)")
        ray.init(local_mode=True)
    elif ray_address == "auto":
        print("  Mode: Auto (local cluster)")
        ray.init()
    else:
        print(f"  Mode: Remote cluster at {ray_address}")
        ray.init(address=ray_address)

    # Display cluster information (best-effort; the dashboard may be disabled)
    # NOTE(review): get_dashboard_url() availability varies across Ray
    # versions -- the except branch keeps this informational only.
    try:
        dashboard_url = ray.get_runtime_context().get_dashboard_url()
        print(f"  Dashboard: {dashboard_url}")
    except Exception:
        print("  Dashboard: Not available")

    resources = ray.available_resources()
    print(f"  Available CPUs: {resources.get('CPU', 0):.0f}")
    print(f"  Available GPUs: {resources.get('GPU', 0):.0f}")

    if config.use_gpu and resources.get('GPU', 0) == 0:
        print("\n⚠️  WARNING: GPU training requested but no GPUs available!")
        response = input("Continue with CPU training? (y/n): ")
        if response.lower() != 'y':
            print("Aborting.")
            ray.shutdown()
            # A declined run is not a success; report a non-zero status.
            return 1

    # Create trainer
    print("\nCreating Ray trainer...")
    trainer = RayF2LLMTrainer(config)

    # Start training
    print("\n" + "=" * 80)
    print(f"Starting training: {config.experiment_id}")
    print("=" * 80)
    print()

    try:
        result = trainer.fit()

        print("\n" + "=" * 80)
        print("Training completed successfully!")
        print("=" * 80)
        print(f"\nResults: {result}")
        print(f"\nCheckpoints saved to: {config.output_dir}")
        print(f"TensorBoard logs: {config.tb_dir}")
        print("\nTo view training metrics:")
        print(f"  tensorboard --logdir {config.tb_dir}")

    except Exception as e:
        print("\n" + "=" * 80)
        print("Training failed!")
        print("=" * 80)
        print(f"\nError: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    finally:
        # Shutdown Ray
        print("\nShutting down Ray...")
        ray.shutdown()

    return 0


if __name__ == "__main__":
    sys.exit(main())