|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import subprocess |
| 17 | +from dataclasses import dataclass, field |
| 18 | + |
| 19 | +import nemo_run as run |
| 20 | +from nemo.collections import llm |
| 21 | + |
| 22 | + |
| 23 | +@dataclass |
| 24 | +class SlurmConfig: |
| 25 | + """Configuration for SlurmExecutor.""" |
| 26 | + |
| 27 | + account: str = "" # Your Slurm account |
| 28 | + partition_cpu: str = "" # Slurm CPU partition to use |
| 29 | + partition_gpu: str = "" # Slurm GPU partition to use |
| 30 | + time: str = "" # Job time limit (HH:MM:SS) |
| 31 | + container_image: str = "" # Container image for jobs |
| 32 | + env_vars: dict[str, str] = field(default_factory=dict) # Environment variables to set |
| 33 | + container_mounts: list[str] = field(default_factory=list) # Container mounts |
| 34 | + use_local_tunnel: bool = False # Set to True if running from within the cluster |
| 35 | + host: str = "" # Required for SSH tunnel: Slurm cluster hostname |
| 36 | + user: str = "" # Required for SSH tunnel: Your username |
| 37 | + job_dir: str = "" # Required for SSH tunnel: Directory to store runs on cluster |
| 38 | + identity: str | None = None # Optional for SSH tunnel: Path to SSH key for authentication |
| 39 | + |
| 40 | + def __post_init__(self): |
| 41 | + """Validate the configuration and raise descriptive errors.""" |
| 42 | + if not self.account: |
| 43 | + raise ValueError("SlurmConfig.account must be set to your actual Slurm account") |
| 44 | + if not self.partition_cpu: |
| 45 | + raise ValueError("SlurmConfig.partition_cpu must be set") |
| 46 | + if not self.partition_gpu: |
| 47 | + raise ValueError("SlurmConfig.partition_gpu must be set") |
| 48 | + if not self.time: |
| 49 | + raise ValueError("SlurmConfig.time must be set to job time limit (e.g., '02:00:00')") |
| 50 | + if not self.container_image: |
| 51 | + raise ValueError("SlurmConfig.container_image must be set to container image for jobs") |
| 52 | + if not self.use_local_tunnel: |
| 53 | + # Only validate SSH tunnel settings if not using local tunnel |
| 54 | + if not self.host: |
| 55 | + raise ValueError( |
| 56 | + "SlurmConfig.host must be set to your actual cluster hostname when using SSH tunnel" |
| 57 | + ) |
| 58 | + if not self.user: |
| 59 | + raise ValueError( |
| 60 | + "SlurmConfig.user must be set to your actual username when using SSH tunnel" |
| 61 | + ) |
| 62 | + if not self.job_dir: |
| 63 | + raise ValueError( |
| 64 | + "SlurmConfig.job_dir must be set to directory for storing runs on cluster" |
| 65 | + ) |
| 66 | + |
| 67 | + self.env_vars |= { |
| 68 | + "CUDA_DEVICE_MAX_CONNECTIONS": "1", # Disable GPU communication/computation overlap for performance |
| 69 | + "TRANSFORMERS_OFFLINE": "1", # Disable online downloads from HuggingFace |
| 70 | + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # Disable caching NCCL communication buffer memory |
| 71 | + "NCCL_NVLS_ENABLE": "0", # Disable NVLink SHARP to save memory |
| 72 | + } |
| 73 | + |
| 74 | + |
def create_slurm_executor(
    slurm_cfg: SlurmConfig, nodes: int = 1, ntasks_per_node: int = 1, num_gpus: int = 0
):
    """Build a ``run.SlurmExecutor`` from *slurm_cfg*.

    Args:
        slurm_cfg: Validated cluster configuration (see :class:`SlurmConfig`).
        nodes: Number of nodes to request.
        ntasks_per_node: Tasks launched per node.
        num_gpus: GPUs per node; ``0`` selects the CPU partition instead.

    Returns:
        A configured ``run.SlurmExecutor`` reached via an SSH or local tunnel.
    """
    # Configure how nemo_run reaches the Slurm head node.
    if slurm_cfg.use_local_tunnel:
        # Use LocalTunnel when already on the cluster
        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir)
    else:
        # Use SSH tunnel when launching from local machine
        tunnel = run.SSHTunnel(
            host=slurm_cfg.host,
            user=slurm_cfg.user,
            job_dir=slurm_cfg.job_dir,
            identity=slurm_cfg.identity,  # can be None
        )

    # Settings shared by the CPU and GPU variants — previously duplicated
    # across the two branches, which let them drift apart.
    common_kwargs = dict(
        account=slurm_cfg.account,
        nodes=nodes,
        # BUG FIX: the CPU branch used to silently ignore ntasks_per_node;
        # the parameter is now honored for both partitions.
        ntasks_per_node=ntasks_per_node,
        tunnel=tunnel,
        container_image=slurm_cfg.container_image,
        container_mounts=slurm_cfg.container_mounts,
        time=slurm_cfg.time,
        packager=run.GitArchivePackager(),
        mem="0",
    )

    if num_gpus > 0:
        return run.SlurmExecutor(
            partition=slurm_cfg.partition_gpu,
            gpus_per_node=num_gpus,
            gres=f"gpu:{num_gpus}",
            **common_kwargs,
        )
    return run.SlurmExecutor(partition=slurm_cfg.partition_cpu, **common_kwargs)
| 118 | + |
| 119 | + |
def get_finetune_recipe(recipe_name: str):
    """Return the full fine-tuning recipe for *recipe_name* from ``nemo.collections.llm``.

    ``peft_scheme=None`` requests full fine-tuning (no PEFT/LoRA adapter).

    Raises:
        ValueError: If the recipe does not exist, or exists but has no
            ``finetune_recipe`` attribute.
    """
    # BUG FIX: guard the recipe's existence first — previously an unknown
    # recipe_name raised an opaque AttributeError from the inner getattr
    # before the friendly ValueError below could be reached.
    if not hasattr(llm, recipe_name):
        raise ValueError(f"Recipe {recipe_name} does not exist")
    recipe = getattr(llm, recipe_name)
    if not hasattr(recipe, "finetune_recipe"):
        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
    return recipe.finetune_recipe(peft_scheme=None)
| 124 | + |
| 125 | + |
def read_chat_template(template_path: str):
    """Return the contents of the chat-template file at *template_path*,
    stripped of leading/trailing whitespace.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(template_path, encoding="utf-8") as f:
        return f.read().strip()
| 129 | + |
| 130 | + |
def download_hf_dataset(dataset_name: str, output_dir: str | None = None):
    """Download a dataset from HuggingFace Hub using huggingface-cli."""
    # Build the CLI invocation as an argument list (no shell involved).
    command = [
        "huggingface-cli",
        "download",
        dataset_name,
        "--repo-type",
        "dataset",
    ]
    if output_dir:
        command += ["--local-dir", output_dir]

    # check=True raises CalledProcessError on a non-zero exit status, so the
    # success message below is only reached after a completed download.
    subprocess.run(command, check=True)
    print(f"Successfully downloaded dataset: {dataset_name}")
0 commit comments