diff --git a/src/mini_trainer/api_train.py b/src/mini_trainer/api_train.py index 74a654b..5220499 100644 --- a/src/mini_trainer/api_train.py +++ b/src/mini_trainer/api_train.py @@ -180,7 +180,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.save_dtype: command.append(f"--save-dtype={train_args.save_dtype}") - + + command.append(f"--device={train_args.device}") + if train_args.torch_compile: + command.append("--torch-compile") + command.append(f"--num-chunks={train_args.num_chunks}") + logger.info("Running training command as subprocess: %s", " ".join(command)) # Run the training process diff --git a/src/mini_trainer/osft_utils.py b/src/mini_trainer/osft_utils.py index cdc92e1..7aa4a61 100644 --- a/src/mini_trainer/osft_utils.py +++ b/src/mini_trainer/osft_utils.py @@ -5,12 +5,12 @@ import torch.distributed as dist import typing as t from typing import Protocol +import multiprocessing from tqdm import tqdm from mini_trainer.utils import log_rank_0, check_distributed_is_synchronized from mini_trainer.gpt_oss_utils import is_gpt_oss_model -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM import os @@ -291,6 +291,11 @@ def cast_to_osft_model(model: torch.nn.Module) -> OSFTModel: raise TypeError(f"Model {type(model)} does not implement OSFT interface") return model # type: ignore + +def create_svd_dict_star(args_tuple): + return create_svd_dict(*args_tuple) + + def create_svd_dict( weight: torch.Tensor, top_k: int, @@ -503,7 +508,7 @@ def partition_svd_computation(target_params, world_size): return assignments -def _broadcast_tensor_device_aware(tensor, src_rank): +def _broadcast_tensor_device_aware(tensor, src_rank, device): """ Broadcasts a tensor that might be on CPU by temporarily moving to GPU if needed. 
@@ -514,9 +519,7 @@ def _broadcast_tensor_device_aware(tensor, src_rank): # If tensor is on CPU, we need to move to GPU for NCCL broadcasting if tensor.device.type == 'cpu': current_rank = dist.get_rank() - local_rank = int(os.environ.get('LOCAL_RANK', 0)) - gpu_device = torch.device('cuda', local_rank) - tensor_gpu = tensor.to(gpu_device) + tensor_gpu = tensor.to(device) # Broadcast the GPU tensor dist.broadcast(tensor_gpu, src=src_rank) @@ -547,6 +550,7 @@ def broadcast_svd_results(model, assignments, world_size): param_to_rank[name] = rank # Broadcast each parameter from its computing rank + device = next(model.parameters()).device for name in param_to_rank: src_rank = param_to_rank[name] # Locate the owning module for this parameter @@ -558,7 +562,7 @@ def broadcast_svd_results(model, assignments, world_size): for tensor_name in ("osft_U_high", "osft_S_high", "osft_V_high"): if hasattr(owner_mod, tensor_name): tensor = getattr(owner_mod, tensor_name) - _broadcast_tensor_device_aware(tensor, src_rank) + _broadcast_tensor_device_aware(tensor, src_rank, device) else: raise AttributeError(f"OSFT broadcast: {tensor_name} not found on module for {name}") @@ -568,9 +572,9 @@ def broadcast_svd_results(model, assignments, world_size): # Broadcast trainable low-rank components (on module.osft_params) if hasattr(owner_mod, "osft_params"): svd_module = owner_mod.osft_params - _broadcast_tensor_device_aware(svd_module.U_low, src_rank) - _broadcast_tensor_device_aware(svd_module.S_low, src_rank) - _broadcast_tensor_device_aware(svd_module.V_low, src_rank) + _broadcast_tensor_device_aware(svd_module.U_low, src_rank, device) + _broadcast_tensor_device_aware(svd_module.S_low, src_rank, device) + _broadcast_tensor_device_aware(svd_module.V_low, src_rank, device) else: raise AttributeError(f"OSFT broadcast: osft_params not found on module for {name}") @@ -857,7 +861,7 @@ def _initialize_osft_with_distribution(model): Returns: The initialized model """ - + if not dist.is_initialized() or dist.get_world_size() == 1: # Simple cases: non-distributed or single process model.reinitialize_osft(decompose_existing_weights=True) @@ -965,6 +969,7 @@ def from_pretrained( if is_gpt_oss: base_kwargs = _filter_osft_parameters(filtered_kwargs, OSFT_GPT_OSS_FILTERED_PARAMS) # For GPT-OSS, we need to use the specific model class + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM actual_osft_cls = create_osft_model_class(GptOssForCausalLM) else: base_kwargs = filtered_kwargs.copy() @@ -1110,7 +1115,7 @@ def forward(x): mod.forward = make_forward(mod, bias) param.requires_grad = False mod._parameters.pop(attr, None) - + # Step 4: Each rank computes SVD for its assigned parameters assigned_params = assignments[global_rank] if assigned_params: @@ -1122,7 +1127,8 @@ def forward(x): # It's possible for processes to be dysynchronized by this point due to the # uneven split of work when processing model layers in parallel. # So here we ensure that everything is synchronized before proceeding. 
- check_distributed_is_synchronized() + device = next(self.parameters()).device + check_distributed_is_synchronized(device) broadcast_svd_results(self, assignments, world_size) torch.distributed.barrier() @@ -1142,6 +1148,48 @@ def _get_module_by_name(self, name): return None, None return mod, attr + def _save_svd_dict_to_module(self, name, param, svd_dict): + safe_name = name.replace(".", "_") + self.name_mapping[name] = safe_name + + # Attach OSFT components to the owning module so only block-local params materialize + mod, attr = self._get_module_by_name(name) + # High-rank frozen components + mod.register_parameter("osft_U_high", nn.Parameter(svd_dict["U_high"], requires_grad=False)) + mod.register_parameter("osft_S_high", nn.Parameter(svd_dict["S_high"], requires_grad=False)) + mod.register_parameter("osft_V_high", nn.Parameter(svd_dict["V_high"], requires_grad=False)) + # Trainable low-rank components + module_svd = nn.Module() + module_svd.U_low = svd_dict["U_low"] + module_svd.S_low = svd_dict["S_low"] + module_svd.V_low = svd_dict["V_low"] + module_svd.rank_high = svd_dict["rank_high"] + module_svd.safe_name = safe_name + mod.add_module("osft_params", module_svd) + + bias = mod.bias if hasattr(mod, "bias") else None + + # Override linear projection to use module-local OSFT params + def make_forward(owner_mod, bias): + def forward(x): + svd_dict = { + "U_high": owner_mod.osft_U_high, + "S_high": owner_mod.osft_S_high, + "V_high": owner_mod.osft_V_high, + "U_low": owner_mod.osft_params.U_low, + "S_low": owner_mod.osft_params.S_low, + "V_low": owner_mod.osft_params.V_low, + "rank_high": owner_mod.osft_params.rank_high, + } + return self._factorized_linear(x, svd_dict, bias) + return forward + + mod.forward = make_forward(mod, bias) + param.requires_grad = False + # Remove original parameter so it doesn't get updated + mod._parameters.pop(attr, None) + torch.cuda.empty_cache() + def _initialize_osft_parameters( self, decompose_existing_weights: bool, assigned_params=None ): @@ -1187,97 +1235,76 @@ def _initialize_osft_parameters( log_rank_0("\033[33m!!!! Initializing OSFT Params!!!!\033[0m") # Set up target device for memory-efficient operations - local_rank = int(os.getenv("LOCAL_RANK", 0)) - target_device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu") + target_device = next(self.parameters()).device if self.osft_memory_efficient_init: log_rank_0(f"🧠 Using ultra-memory-efficient OSFT initialization for device {target_device}") - for name, param in named_params: - # Apply SVD only to 2D matrices in the target config (e.g., q_proj, down_proj, etc.) - if is_osft_param(name, param, self.osft_config): - top_k = self.osft_config[name] - - if self.osft_memory_efficient_init: - # Memory monitoring before processing - if torch.cuda.is_available(): - mem_before = torch.cuda.memory_allocated(target_device) / 1e9 - log_rank_0(f"🔄 Processing {name} with incremental GPU usage (top_k={top_k}) - GPU mem: {mem_before:.2f}GB") + if target_device.type != "hpu": + for name, param in named_params: + # Apply SVD only to 2D matrices in the target config (e.g., q_proj, down_proj, etc.) 
+ if is_osft_param(name, param, self.osft_config): + top_k = self.osft_config[name] - # Memory-efficient processing: move parameter to GPU temporarily for SVD - param_gpu = param.data.to(target_device) - - # Perform SVD on GPU - svd_dict = create_svd_dict( - param_gpu, - top_k=top_k, - decompose_existing=decompose_existing_weights, - upcast_dtype=self.upcast_dtype, - output_dtype=self.output_dtype, - ) - - # Move SVD components to target device and clear GPU cache - for key in svd_dict: - if isinstance(svd_dict[key], torch.Tensor): - svd_dict[key] = svd_dict[key].to(target_device) - - # Clear the temporary GPU tensor - del param_gpu - torch.cuda.empty_cache() - - # Memory monitoring after processing - if torch.cuda.is_available(): - mem_after = torch.cuda.memory_allocated(target_device) / 1e9 - log_rank_0(f"✅ Completed {name} - GPU mem: {mem_after:.2f}GB (freed: {mem_before-mem_after:.2f}GB)") - else: - # Standard processing - svd_dict = create_svd_dict( - param.data, - top_k=top_k, - decompose_existing=decompose_existing_weights, - upcast_dtype=self.upcast_dtype, - output_dtype=self.output_dtype, - ) - safe_name = name.replace(".", "_") - self.name_mapping[name] = safe_name - - # Attach OSFT components to the owning module so only block-local params materialize - mod, attr = self._get_module_by_name(name) - # High-rank frozen components - mod.register_parameter("osft_U_high", nn.Parameter(svd_dict["U_high"], requires_grad=False)) - mod.register_parameter("osft_S_high", nn.Parameter(svd_dict["S_high"], requires_grad=False)) - mod.register_parameter("osft_V_high", nn.Parameter(svd_dict["V_high"], requires_grad=False)) - # Trainable low-rank components - module_svd = nn.Module() - module_svd.U_low = svd_dict["U_low"] - module_svd.S_low = svd_dict["S_low"] - module_svd.V_low = svd_dict["V_low"] - module_svd.rank_high = svd_dict["rank_high"] - module_svd.safe_name = safe_name - mod.add_module("osft_params", module_svd) - - bias = mod.bias if hasattr(mod, "bias") else None - - # Override linear projection to use module-local OSFT params - def make_forward(owner_mod, bias): - def forward(x): - svd_dict = { - "U_high": owner_mod.osft_U_high, - "S_high": owner_mod.osft_S_high, - "V_high": owner_mod.osft_V_high, - "U_low": owner_mod.osft_params.U_low, - "S_low": owner_mod.osft_params.S_low, - "V_low": owner_mod.osft_params.V_low, - "rank_high": owner_mod.osft_params.rank_high, - } - return self._factorized_linear(x, svd_dict, bias) - return forward - - mod.forward = make_forward(mod, bias) - param.requires_grad = False - # Remove original parameter so it doesn't get updated - mod._parameters.pop(attr, None) - torch.cuda.empty_cache() + if self.osft_memory_efficient_init: + # Memory monitoring before processing + if torch.cuda.is_available(): + mem_before = torch.cuda.memory_allocated(target_device) / 1e9 + log_rank_0(f"🔄 Processing {name} with incremental GPU usage (top_k={top_k}) - GPU mem: {mem_before:.2f}GB") + + # Memory-efficient processing: move parameter to GPU temporarily for SVD + param_gpu = param.data.to(target_device) + + # Perform SVD on GPU + svd_dict = create_svd_dict( + param_gpu, + top_k=top_k, + decompose_existing=decompose_existing_weights, + upcast_dtype=self.upcast_dtype, + output_dtype=self.output_dtype, + ) + + # Move SVD components to target device and clear GPU cache + for key in svd_dict: + if isinstance(svd_dict[key], torch.Tensor): + svd_dict[key] = svd_dict[key].to(target_device) + + # Clear the temporary GPU tensor + del param_gpu + torch.cuda.empty_cache() + + # 
Memory monitoring after processing + if torch.cuda.is_available(): + mem_after = torch.cuda.memory_allocated(target_device) / 1e9 + log_rank_0(f"✅ Completed {name} - GPU mem: {mem_after:.2f}GB (freed: {mem_before-mem_after:.2f}GB)") + else: + # Standard processing + svd_dict = create_svd_dict( + param.data, + top_k=top_k, + decompose_existing=decompose_existing_weights, + upcast_dtype=self.upcast_dtype, + output_dtype=self.output_dtype, + ) + self._save_svd_dict_to_module(name, param, svd_dict) + + else: + osft_params_to_process = [] + for name, param in named_params: + if is_osft_param(name, param, self.osft_config): + osft_params_to_process.append((name, param, param.data, self.osft_config[name], decompose_existing_weights, self.upcast_dtype, self.output_dtype)) + + svd_args = [(param_data.to('cpu'), top_k, decompose_existing, upcast_dtype, output_dtype) + for _, _, param_data, top_k, decompose_existing, upcast_dtype, output_dtype in osft_params_to_process] + + os.environ['PT_HPU_AUTOLOAD'] = '0' + num_cpu_per_rank = 8 # TODO add command line parameter + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(num_cpu_per_rank) as pool: + res = pool.imap(create_svd_dict_star, svd_args) + for i, svd_dict in enumerate(res): + name, param, _, _, _, _, _ = osft_params_to_process[i] + self._save_svd_dict_to_module(name, param, svd_dict) # Barrier for synchronization in distributed setting if dist.is_initialized(): diff --git a/src/mini_trainer/setup_model_for_training.py b/src/mini_trainer/setup_model_for_training.py index 88994d7..ffe59ae 100644 --- a/src/mini_trainer/setup_model_for_training.py +++ b/src/mini_trainer/setup_model_for_training.py @@ -8,7 +8,7 @@ checkpoint_wrapper as ptd_checkpoint_wrapper, ) from torch.distributed.device_mesh import init_device_mesh -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, Mxfp4Config +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from mini_trainer.utils import get_model_class_from_config, log_rank_0, patch_target_module from mini_trainer.osft_utils import OSFTModel, _build_osft_kwargs, _initialize_osft_with_distribution, _set_osft_dtypes, create_osft_model_class from mini_trainer.gpt_oss_utils import freeze_router_params, is_gpt_oss_model @@ -44,14 +44,18 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module: # 3) Build a 1D device mesh over all ranks world_size = dist.get_world_size() - mesh = init_device_mesh("cuda", [world_size], mesh_dim_names=["fsdp"]) + device = next(model.parameters()).device.type + mesh = init_device_mesh(device, [world_size], mesh_dim_names=["fsdp"]) # 4) Mixed-precision policy using bfloat16 for Flash Attention compatibility # Flash Attention requires bfloat16 for proper operation - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - ) + if device == "cuda": + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ) + else: + mp_policy = MixedPrecisionPolicy() # 4) FSDP2 wrap each block for idx, block in enumerate(layers): @@ -181,6 +185,7 @@ def setup_osft_model( osft_upcast_dtype=torch.float32, osft_output_dtype=None, osft_memory_efficient_init: bool = False, + device: str = "cuda", ): """ High-level function to set up an OSFT model with all necessary configuration. 
@@ -239,7 +244,7 @@ def setup_osft_model( _set_osft_dtypes(model, osft_upcast_dtype, osft_output_dtype) # Handle initialization based on memory_efficient_init flag - device = torch.device("cuda", rank) + device = torch.device(device, rank) if osft_memory_efficient_init: # Memory-efficient: Initialize OSFT on CPU, then move to GPU @@ -268,6 +273,7 @@ def setup_model( osft_target_patterns: list[str] | None = None, use_liger_kernels: bool = False, osft_memory_efficient_init: bool = False, + device: str = "cuda", ) -> torch.nn.Module | OSFTModel: base_model_args = { "pretrained_model_name_or_path": model_name_or_path, @@ -282,6 +288,7 @@ def setup_model( if is_gpt_oss: try: # Try to specify the target dtype for dequantization + from transformers import Mxfp4Config quantization_config = Mxfp4Config(dequantize=True) # If the config supports dtype specification, use it if hasattr(quantization_config, 'torch_dtype'): @@ -293,18 +300,19 @@ def setup_model( log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config") # Check if flash_attn is available and set appropriate attention implementation - try: - import flash_attn - if is_gpt_oss: - base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3" - log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS") - else: - base_model_args["attn_implementation"] = "flash_attention_2" - except ImportError as e: - if os.environ.get("TESTING", "false").lower() == "true": - base_model_args["attn_implementation"] = "eager" - else: - raise e + if device != "hpu": + try: + import flash_attn + if is_gpt_oss: + base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3" + log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS") + else: + base_model_args["attn_implementation"] = "flash_attention_2" + except ImportError as e: + if os.environ.get("TESTING", "false").lower() == "true": + base_model_args["attn_implementation"] = "eager" + else: + raise e tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) @@ -359,6 +367,7 @@ def load_osft_model(): osft_upcast_dtype=osft_upcast_dtype, osft_output_dtype=effective_osft_output_dtype, osft_memory_efficient_init=osft_memory_efficient_init, + device=device, ) # Choose whether to apply orthogonal subspace learning (OSL) based on `osft` flag diff --git a/src/mini_trainer/train.py b/src/mini_trainer/train.py index d212d46..24c3486 100644 --- a/src/mini_trainer/train.py +++ b/src/mini_trainer/train.py @@ -308,7 +308,13 @@ def compute_validation_loss(model, val_data_loader, device): } # Forward pass - output = model(**model_inputs) + hpu_args = {} + if device.type == "hpu": + hpu_args = { + "use_flash_attention": True, + "lazy_mode": False, + } + output = model(**model_inputs, **hpu_args) loss = output.loss.float().sum() loss_metrics = loss.detach().item() @@ -718,7 +724,8 @@ def train( for batch in data_loader_it: batch_start_time = time.time() batch_totals.reset_batch() - torch.cuda.reset_peak_memory_stats() + if device.type == "cuda": + torch.cuda.reset_peak_memory_stats() for grad_accum, mb in enumerate(batch): mb_start_time = time.time() mb_num_loss_counted_tokens = mb['num_loss_counted_tokens'] @@ -732,7 +739,13 @@ def train( 'position_ids': mb['position_ids'].to(device), } - output = model(**model_inputs) + hpu_args = {} + if device.type == "hpu": + hpu_args = { + "use_flash_attention": True, + "lazy_mode": False, + } + output = model(**model_inputs, **hpu_args) # GPT-OSS: add auxiliary loss if present, otherwise 
use standard loss if hasattr(output, 'aux_loss') and output.aux_loss is not None: @@ -788,7 +801,7 @@ def train( "total_samples_accumulated": total_samples_accumulated, "total_tokens_accumulated": total_tokens_processed, "samples_per_second": bm['num_samples']/batch_time if batch_time > 0 else 0.0, - "peak_memory_usage_GB": float(torch.cuda.max_memory_allocated() / 1e9), + "peak_memory_usage_GB": float(torch.cuda.max_memory_allocated() / 1e9) if device.type == "cuda" else 0.0, 'val_loss': last_validation_loss, } # Add validation metrics if it's time to validate @@ -962,9 +975,14 @@ def main( wandb_project: Annotated[str | None, Option(help="Weights & Biases project name")] = None, wandb_run_name: Annotated[str | None, Option(help="Weights & Biases run name")] = None, wandb_entity: Annotated[str | None, Option(help="Weights & Biases entity/team name")] = None, + + # HPU specific parameters + device: Annotated[str, Option(help="Device to use for training ('cuda' or 'hpu')")] = "cuda", + torch_compile: Annotated[bool, Option(help="Enable torch.compile (HPU only)")] = False, + num_chunks: Annotated[int, Option(help="Number of chunks to split dataset into for sequential training")] = 1, ): - init_distributed_environment() + init_distributed_environment(device) # TODO: make the path creation lazy, but confirm that we can write to the given directory # at this point output_path = Path(output_dir) @@ -1037,6 +1055,9 @@ def main( "GLOBAL_RANK": global_rank, "NODE_RANK": node_rank, "WORLD_SIZE": world_size, + "device": device, + "torch_compile": torch_compile, + "num_chunks": num_chunks, } # Initialize wandb with the same params config @@ -1114,6 +1135,7 @@ def main( osft_upcast_dtype=osft_upcast_dtype_torch, osft_output_dtype=osft_output_dtype_torch, osft_memory_efficient_init=osft_memory_efficient_init, + device=device, ) model, optimizer, lr_scheduler = setup_training_components( model=model, diff --git a/src/mini_trainer/training_types.py b/src/mini_trainer/training_types.py index 8217f91..b807822 100644 --- a/src/mini_trainer/training_types.py +++ b/src/mini_trainer/training_types.py @@ -108,3 +108,9 @@ class TrainingArgs: # from train.py: save_best_val_loss: bool = field(default=False, metadata={"help": "Whether to save checkpoints when validation loss improves"}) val_loss_improvement_threshold: float = field(default=0.0, metadata={"help": "Minimum validation loss improvement required to trigger a save"}) + + # HPU specific parameters + device: str = field(default="cuda", metadata={"help": "Device to use for training ('cuda' or 'hpu')"}) + torch_compile: bool = field(default=False, metadata={"help": "Enable torch.compile (HPU only)"}) + num_chunks: int = field(default=1, metadata={"help": "Number of chunks to split dataset into for sequential training"}) + \ No newline at end of file diff --git a/src/mini_trainer/utils.py b/src/mini_trainer/utils.py index 3ba4f51..7190f45 100644 --- a/src/mini_trainer/utils.py +++ b/src/mini_trainer/utils.py @@ -64,13 +64,11 @@ def patch_target_module( setattr(source, obj_name_to_patch, replace_with) -def check_distributed_is_synchronized(): +def check_distributed_is_synchronized(device): """ This function runs a simple check to verify that torch.distributed is functioning properly and all processes are synchronized. 
""" - local_rank = int(os.environ["LOCAL_RANK"]) - device = torch.device("cuda", local_rank) t = torch.tensor([1]).to(device, torch.int32) # Here, every process group increments the counter @@ -82,7 +80,7 @@ def check_distributed_is_synchronized(): assert t.item() == dist.get_world_size(), "❌ Error: distributed check failed" -def check_distributed_is_evenly_configured(): +def check_distributed_is_evenly_configured(device): """ DDP, FSDP1, and FSDP2 do not support uneven world-size configurations, and therefore neither do our distributed computing algorithms (e.g. distributed SVD init). @@ -99,7 +97,6 @@ def check_distributed_is_evenly_configured(): f"world_size ({world_size}) is not cleanly divisible by local_world_size ({local_world_size}). Each node must have the same number of GPUs." ) - device = torch.device("cuda", local_rank) max_local_rank_seen = torch.tensor([local_rank], dtype=torch.int32, device=device) dist.all_reduce(max_local_rank_seen, op=dist.ReduceOp.MAX) if max_local_rank_seen[0] != local_world_size - 1: @@ -108,18 +105,24 @@ def check_distributed_is_evenly_configured(): ) -def init_distributed_environment(): +def init_distributed_environment(device: str = "cuda"): local_rank = int(os.environ["LOCAL_RANK"]) - device = torch.device("cuda", local_rank) + device = torch.device(device, local_rank) torch.distributed.init_process_group( - "nccl", timeout=timedelta(minutes=180), device_id=device + "nccl" if device.type == "cuda" else "hccl", + timeout=timedelta(minutes=180), + device_id=device ) # NOTE(osilkin): PyTorch wants us to avoid this API in favor of setting the device explicitly # through `init_process_group`, but without setting this, FSDP2 will shard the # entire model onto the first GPU. I haven't yet figured out a solution to this. - torch.cuda.set_device(local_rank) - check_distributed_is_synchronized() - check_distributed_is_evenly_configured() + if device.type == "cuda": + torch.cuda.set_device(local_rank) + else: + torch.hpu.set_device(local_rank) + + check_distributed_is_synchronized(device) + check_distributed_is_evenly_configured(device) log_rank_0("✅ Torch distributed appears to be functioning correctly") torch.distributed.barrier()