Skip to content

Commit 02d956c

Browse files
author
Orbax Authors
committed
Add benchmarks for P2P CheckpointManager.
PiperOrigin-RevId: 873854861
1 parent b912119 commit 02d956c

File tree

7 files changed

+900
-3
lines changed

7 files changed

+900
-3
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI (inter-chip interconnect): parallelism within a slice. Assuming 8 devices per slice.
7+
# DCN (data center network): parallelism across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # the value on the axis at replica_axis_index is the number of slices
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 1, "model": 16}
16+
dcn_parallelism: {"data": 4, "model": 1}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 2, "model": 8}
20+
dcn_parallelism: {"data": 2, "model": 1}
21+
allow_split_physical_axes: true
22+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
23+
ici_parallelism: {"data": 2, "model": 4}
24+
dcn_parallelism: {"data": 2, "model": 1}
25+
allow_split_physical_axes: true
26+
27+
checkpoint_config:
28+
spec:
29+
a_1d: {dtype: "float32", shape: [32], sharding: [null]}
30+
b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
31+
c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
32+
d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
33+
e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
34+
f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
35+
g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
36+
h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
37+
i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
38+
j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
39+
k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
40+
custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
41+
42+
benchmarks:
43+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
44+
options:
45+
persistent_save_interval_steps: [2]
46+
persistent_max_to_keep: [5]
47+
local_save_interval_steps: [2]
48+
local_max_to_keep: 2
49+
replica_axis_index: 0
50+
train_steps: 5
51+
experimental_orbax_use_distributed_process_id: true
52+
experimental_use_distributed_id_for_mesh_consistency: true

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def run(self, repeat_index: int | None = None) -> TestResult:
159159
path = directory_setup.setup_test_directory(
160160
self.name, self.output_dir, repeat_index
161161
)
162+
local_path = None
163+
if self.local_directory is not None:
164+
local_path = epath.Path(self.local_directory) / name
165+
if repeat_index is not None:
166+
local_path = local_path / f"repeat_{repeat_index}"
162167

163168
with benchmark_metrics.measure(
164169
"sync_global_processes:benchmark:setup_test_directory"
@@ -185,7 +190,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
185190
options=self.options,
186191
mesh=self.mesh,
187192
repeat_index=repeat_index,
188-
local_path=self.local_directory,
193+
local_path=local_path,
189194
)
190195

191196
test_context_summary = self._build_test_context_summary(context)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def setup_test_directory(
4040
path = path / f"repeat_{repeat_index}"
4141
logging.info("Setting up test directory at: %s", path)
4242
if jax.process_index() == 0:
43-
if path.exists():
43+
if path.exists() and not base_path.startswith("gs://"):
4444
logging.warning("Test directory %s already exists. Deleting it.", path)
4545
path.rmtree()
46-
path.mkdir(parents=True, exist_ok=False)
46+
path.mkdir(parents=True, exist_ok=True)
4747
return path

0 commit comments

Comments
 (0)