-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmulti_host_test.py
More file actions
executable file
·91 lines (72 loc) · 2.95 KB
/
multi_host_test.py
File metadata and controls
executable file
·91 lines (72 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
r"""
Multi-host TPU pod test script.

This script should be run on ALL hosts simultaneously.

Usage with gcloud:
    gcloud compute tpus tpu-vm ssh YOUR_TPU \
        --zone YOUR_ZONE \
        --worker=all \
        --command="cd ~/tunix && python3 multi_host_test.py"
"""
import os
import sys
import jax
import jax.numpy as jnp
def print_separator(char="=", length=60):
    """Print a horizontal rule made of ``char`` repeated ``length`` times.

    Args:
        char: The value to repeat; defaults to ``"="``.
        length: How many times to repeat ``char``; defaults to 60.
    """
    rule = char * length
    print(rule)
def main():
    """Run the multi-host TPU smoke test on this process.

    Prints the JAX version and the local/global device layout, waits at a
    global synchronization point, runs a small matrix multiplication, and
    waits at a second synchronization point before process 0 prints a
    summary.

    Returns:
        0 if every step completed on this process, 1 if any exception was
        raised (the error and its traceback are printed).
    """
    print_separator()
    print("JAX Multi-Host TPU Test")
    print_separator()

    try:
        # `import jax` alone is not guaranteed to load this submodule, so
        # `jax.experimental.multihost_utils` can raise AttributeError.
        # Import it explicitly before use.
        from jax.experimental import multihost_utils

        # $HOSTNAME is frequently a non-exported shell variable (e.g. under
        # a non-interactive ssh command), so fall back to the OS hostname.
        import socket

        hostname = os.environ.get('HOSTNAME') or socket.gethostname()

        process_index = jax.process_index()
        process_count = jax.process_count()

        # Basic JAX info
        print(f"\n[Process {process_index}/{process_count}]")
        print(f"JAX version: {jax.__version__}")
        print(f"Hostname: {hostname}")

        # Device counts
        print("\nDevice Information:")
        print(f" Local devices (this host): {jax.local_device_count()}")
        print(f" Global devices (all hosts): {jax.device_count()}")
        print(f" Process index: {process_index}")
        print(f" Process count: {process_count}")

        # Show local devices
        print("\nLocal devices on this host:")
        for i, device in enumerate(jax.local_devices()):
            print(f" {i}: {device}")

        # Only process 0 shows global info to avoid spam
        if process_index == 0:
            print("\nAll devices across all hosts:")
            for i, device in enumerate(jax.devices()):
                print(f" {i}: {device}")

        # Synchronization test: every process must reach this call.
        print(f"\n[Process {process_index}] Synchronizing...")
        multihost_utils.sync_global_devices("test_sync")
        if process_index == 0:
            print(f"\n✓ All {process_count} hosts synchronized!")

        # Simple computation test: a small matmul confirms that this
        # process can compile and execute on its backend.
        print(f"\n[Process {process_index}] Running computation...")
        local_size = 100
        x = jnp.ones((local_size, local_size))
        y = x @ x  # Matrix multiplication
        result = y.sum()
        print(f"[Process {process_index}] Local computation result: {result}")

        # Final sync so the summary is only printed once all processes
        # have finished their computation.
        multihost_utils.sync_global_devices("test_complete")
        if process_index == 0:
            print("\n" + "=" * 60)
            print("✓ SUCCESS: Multi-host TPU pod is working!")
            print_separator()
            print("\nSummary:")
            print(f" Total hosts: {process_count}")
            print(f" Total TPU chips: {jax.device_count()}")
            print(f" Chips per host: {jax.local_device_count()}")
            print("\nYour TPU pod is ready for training!")
        return 0
    except Exception as e:
        # Looking up the process index can itself fail (for example when
        # backend initialization is what raised); guard it so the original
        # error is still reported instead of being masked.
        try:
            failed_process = jax.process_index()
        except Exception:
            failed_process = "unknown"
        print(f"\n✗ ERROR on process {failed_process}: {e}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the
    # process exit status.
    exit_code = main()
    sys.exit(exit_code)