
Commit ea0fca2

add multi node support
1 parent 2c9e31c commit ea0fca2

2 files changed: +38 -4 lines changed

scripts/context_length_test/README.md

Lines changed: 2 additions & 0 deletions
@@ -53,13 +53,15 @@ python search_context_length_capacity.py \
 |--------|--------|-----------|
 | `--start_length` | `4096` | Initial context length to begin testing. |
 | `--log_dir` | `./logs` | Directory to save logs and results. |
+| `--checkpoint_path` | `os.environ.get("TRINITY_CHECKPOINT_ROOT_DIR", "./checkpoints/length-test")` | Checkpoint path for testing. Note that this directory will be deleted during the test, please specify a path that is not used by other processes. |
 | `--test_gpu_num` | `1 2 4 6` | List of GPU counts to test scalability. |
 | `--test_sp_num` | `1` | Sequence parallel group sizes to evaluate. Must divide `test_gpu_num` and number of attention heads. |
 | `--save_hf_checkpoint` | `last` | When to save HF format checkpoints (`always`, `never`, `last`). |
 | `--entropy_saving` | `False` | Enable memory-saving techniques (if supported). |
 | `--offload` | `False` | Offload parameters to CPU to reduce GPU memory usage. |
 | `--trainer_strategy` | `fsdp` | Distributed training strategy (`fsdp` or `fsdp2`). |
 | `--timeout` | `2400` (40 min) | Maximum time per job before forced termination. |
+| `--dlc` | `False` | Specify when running in Aliyun PAI DLC. |
 
 ---
 
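
For reference, a hypothetical invocation exercising the two new flags might look like the following. The paths are placeholders, `--model_path` is inferred from the script's `args.model_path` (its argparse definition is not shown in this diff), and `--dlc` should only be passed when the job actually runs inside Aliyun PAI DLC:

python search_context_length_capacity.py \
    --model_path /path/to/model \
    --test_gpu_num 2 4 \
    --log_dir ./logs \
    --checkpoint_path /tmp/length-test-ckpt \
    --dlc

Because the checkpoint directory is wiped during the test, the example points it at a throwaway location rather than a shared checkpoint root.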
scripts/context_length_test/search_context_length_capacity.py

Lines changed: 36 additions & 4 deletions
@@ -15,6 +15,8 @@
 import transformers
 import yaml
 
+from trinity.utils.dlc_utils import is_running, setup_ray_cluster, stop_ray_cluster
+
 # Default list of GPU counts to test
 DEFAULT_GPU_NUMS: List[int] = [1, 2, 4, 6]
 EXCEPTION_STRING = "Traceback (most recent call last)"
@@ -51,6 +53,9 @@ def monitor_output(
         if EXCEPTION_STRING in line:
             exception_event.set()
 
+        if exception_event.is_set():
+            print(line, end="", flush=True)
+
         # Check for oom
         if OOM_STRING in line:
             exception_event.set()
@@ -64,6 +69,7 @@ def run_command_with_monitor(
     command: List[str],
     envs: dict[str, str],
     log_path: str,
+    checkpoint_path: str,
     timeout: Optional[int] = None,
 ) -> bool:
     """Runs a shell command with real-time output monitoring and early termination support.
@@ -77,26 +83,27 @@
         command: Command to execute, as a list of strings.
         envs: Environment variables to set for the command.
         log_path: Path to the log file where output will be saved.
+        checkpoint_path: Path to the checkpoint directory.
         timeout: Optional timeout in seconds before forcing termination.
 
     Returns:
         True if the command completed successfully without OOM error; False otherwise.
     """
     retry_flag = True
     success_flag = False
-    checkpoint_root = os.environ.get("TRINITY_CHECKPOINT_ROOT_DIR", "./checkpoints/length-test")
+    envs["TRINITY_CHECKPOINT_ROOT_DIR"] = checkpoint_path
+    process_env = os.environ.copy()
+    process_env.update(envs)
 
     while retry_flag:
         # Clean up checkpoint directory before each run
-        shutil.rmtree(checkpoint_root, ignore_errors=True)
+        shutil.rmtree(checkpoint_path, ignore_errors=True)
 
         exception_event = threading.Event()
         oom_event = threading.Event()
 
         with open(log_path, "w", encoding="utf-8") as log_file:
             # Start subprocess with merged stdout/stderr
-            process_env = os.environ.copy()
-            process_env.update(envs)
             process = subprocess.Popen(
                 command,
                 stdout=subprocess.PIPE,
@@ -160,6 +167,7 @@ def run_command_with_monitor(
 def find_max_model_len(
     model_path: str,
     model_config,
+    checkpoint_path: str,
     trainer_gpu_num: int,
     sp_num: int,
     base_log_dir: str,
@@ -178,6 +186,7 @@ def find_max_model_len(
     Args:
         model_path: Path to the pretrained model.
         model_config: Loaded Hugging Face model configuration.
+        checkpoint_path: Path to the checkpoint directory.
         trainer_gpu_num: Number of GPUs allocated.
         sp_num: Number of sequence parallel groups.
         base_log_dir: Base directory for saving logs.
@@ -253,6 +262,7 @@ def find_max_model_len(
             cmd_base,
             cmd_env,
             logfile,
+            checkpoint_path,
             timeout=timeout,
         )
 
@@ -278,6 +288,13 @@ def find_max_model_len(
 
 def main(args):
     """Main entry point: orchestrates multi-GPU, multi-SP context length testing."""
+    if args.dlc:
+        cluster_namespace = "search_context_length_capacity"
+        setup_ray_cluster(namespace=cluster_namespace)
+
+    if not is_running():
+        raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
+
     os.makedirs(args.log_dir, exist_ok=True)
 
     model_name = os.path.basename(args.model_path)
@@ -300,6 +317,7 @@ def main(args):
         max_length = find_max_model_len(
             model_path=args.model_path,
             model_config=model_config,
+            checkpoint_path=args.checkpoint_path,
            trainer_gpu_num=trainer_gpu_num,
             sp_num=sp_num,
             base_log_dir=args.log_dir,
@@ -319,6 +337,9 @@ def main(args):
             f"max_model_len = {max_length}"
         )
 
+    if args.dlc:
+        stop_ray_cluster(namespace=cluster_namespace)
+
 
 if __name__ == "__main__":
     default_log_dir = os.path.join(os.path.dirname(__file__), "logs")
@@ -343,6 +364,14 @@ def main(args):
         default=default_log_dir,
         help="Directory to store experiment logs.",
     )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+        default=os.environ.get("TRINITY_CHECKPOINT_ROOT_DIR", "./checkpoints/length-test"),
+        help="Checkpoint path for testing. "
+        "Note that this directory will be deleted during the test, "
+        "please specify a path that is not used by other processes.",
+    )
     parser.add_argument(
         "--test_gpu_num",
         type=int,
@@ -387,6 +416,9 @@ def main(args):
         default=2400,
         help="Timeout for each experiment in seconds.",
     )
+    parser.add_argument(
+        "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
+    )
 
     args = parser.parse_args()
     main(args)
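
To make the checkpoint-path plumbing concrete: `run_command_with_monitor` now receives `checkpoint_path` explicitly and forwards it to the subprocess by overriding `TRINITY_CHECKPOINT_ROOT_DIR` in a copy of the parent environment, instead of reading a global default itself. The snippet below is a minimal standalone sketch of that pattern, not code from this repository; the helper name `run_with_checkpoint_root`, the placeholder child command, and the `/tmp` path are illustrative only.

import os
import subprocess
from typing import Dict, List


def run_with_checkpoint_root(command: List[str], envs: Dict[str, str], checkpoint_path: str) -> int:
    """Run `command` with `envs` applied on top of the parent environment,
    overriding TRINITY_CHECKPOINT_ROOT_DIR for the child process only."""
    child_envs = dict(envs)
    child_envs["TRINITY_CHECKPOINT_ROOT_DIR"] = checkpoint_path
    process_env = os.environ.copy()   # parent environment stays untouched
    process_env.update(child_envs)    # per-run overrides win
    return subprocess.run(command, env=process_env, check=False).returncode


if __name__ == "__main__":
    exit_code = run_with_checkpoint_root(
        ["python", "-c", "import os; print(os.environ['TRINITY_CHECKPOINT_ROOT_DIR'])"],
        envs={},
        checkpoint_path="/tmp/length-test-ckpt",
    )
    print("child exit code:", exit_code)

Passing the path through the environment keeps the training entrypoint unchanged while letting each sweep run get its own disposable checkpoint directory, which the script deletes with `shutil.rmtree` before every retry.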
