|
| 1 | +# The name for the entire test suite run. |
| 2 | +suite_name: "P2P CheckpointManager Benchmark" |
| 3 | + |
| 4 | +mesh_configs: |
| 5 | + - mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"] |
| 6 | + # ICI: Within a slice. Assuming 8 devices per slice. |
| 7 | + # DCN: Across slices. |
| 8 | + ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1} |
| 9 | + dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index |
| 10 | + process_is_granule: true |
| 11 | + - mesh_axes: ["data", "model", "tensor", "fsdp"] |
| 12 | + ici_parallelism: {"data": 1, "model": 1} |
| 13 | + dcn_parallelism: {"data": 4, "model": 1} |
| 14 | + - mesh_axes: ["data", "model", "tensor", "fsdp"] |
| 15 | + ici_parallelism: {"data": 1, "model": 16} |
| 16 | + dcn_parallelism: {"data": 4, "model": 1} |
| 17 | + allow_split_physical_axes: true |
| 18 | + - mesh_axes: ["data", "model", "tensor", "fsdp"] |
| 19 | + ici_parallelism: {"data": 2, "model": 8} |
| 20 | + dcn_parallelism: {"data": 2, "model": 1} |
| 21 | + allow_split_physical_axes: true |
| 22 | + - mesh_axes: ["data", "model", "tensor", "fsdp"] |
| 23 | + ici_parallelism: {"data": 2, "model": 4} |
| 24 | + dcn_parallelism: {"data": 2, "model": 1} |
| 25 | + allow_split_physical_axes: true |
| 26 | + |
| 27 | +checkpoint_config: |
| 28 | + spec: |
| 29 | + a_1d: {dtype: "float32", shape: [32], sharding: [null]} |
| 30 | + b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]} |
| 31 | + c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]} |
| 32 | + d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]} |
| 33 | + e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]} |
| 34 | + f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]} |
| 35 | + g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]} |
| 36 | + h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]} |
| 37 | + i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]} |
| 38 | + j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]} |
| 39 | + k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]} |
| 40 | + custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]} |
| 41 | + |
| 42 | +benchmarks: |
| 43 | + - generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark" |
| 44 | + options: |
| 45 | + persistent_save_interval_steps: [2] |
| 46 | + persistent_max_to_keep: [5] |
| 47 | + local_save_interval_steps: [2] |
| 48 | + local_max_to_keep: 2 |
| 49 | + replica_axis_index: 0 |
| 50 | + train_steps: 5 |
| 51 | + experimental_orbax_use_distributed_process_id: true |
| 52 | + experimental_use_distributed_id_for_mesh_consistency: true |
0 commit comments