This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 872fe5a

[FEATURE] pass global config to worker and set manual sharding of intermediates (#928)

1 parent: 4e92039

6 files changed, +218 -43 lines

alpa/device_mesh.py (6 additions, 2 deletions)

@@ -111,7 +111,7 @@ class MeshHostWorker:
 
     def __init__(self, server_address: str, num_hosts: int, host_id: int,
                  mesh_id: int, move_worker: DaemonMoveWorker,
-                 runtime_random_seed: int):
+                 runtime_random_seed: int, worker_global_config: dict):
         self.num_hosts = num_hosts
         self.host_id = host_id
         self.mesh_id = mesh_id
@@ -124,6 +124,9 @@ def __init__(self, server_address: str, num_hosts: int, host_id: int,
         self.distributed_client.connect()
         logger.debug(
             f"{host_id}: Success to connect to xla runtime at {server_address}")
+
+        # Set global config to follow the driver
+        global_config.update_worker_config(worker_global_config)
         if global_config.backend == "gpu":
             self.backend = xla_client.make_gpu_client(self.distributed_client,
                                                       node_id=host_id)
@@ -1139,7 +1142,8 @@ def launch_xla_servers(self):
                 "env_vars": env_vars
             }).remote(server_address, self.num_hosts, i,
                       self.mesh_id, move_worker,
-                      global_config.runtime_random_seed)
+                      global_config.runtime_random_seed,
+                      global_config)
             workers.append(worker)
         return service_server, workers
 
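The gist of this file's change: each MeshHostWorker Ray actor now receives the driver's global_config at construction time and copies it into its own process before creating the XLA backend, so driver and workers agree on backend choice, seeds, and runtime options. The following toy sketch shows the same pass-the-config-into-the-actor pattern with plain Ray; the class and field names are illustrative stand-ins, not Alpa's API.

# Toy sketch of propagating a driver-side config into a Ray actor.
# DriverConfig and HostWorker are illustrative stand-ins, not Alpa classes.
import ray


class DriverConfig:

    def __init__(self, backend="gpu", runtime_random_seed=42):
        self.backend = backend
        self.runtime_random_seed = runtime_random_seed


@ray.remote
class HostWorker:

    def __init__(self, worker_global_config):
        # Keep a copy of the driver's settings so later decisions
        # (backend creation, RNG seeding) match the driver exactly.
        self.config = worker_global_config

    def backend(self):
        return self.config.backend


ray.init()
driver_config = DriverConfig()
worker = HostWorker.remote(driver_config)  # the config object is pickled over
assert ray.get(worker.backend.remote()) == "gpu"
ray.shutdown()
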
alpa/global_env.py (30 additions, 0 deletions)

@@ -105,6 +105,36 @@ def ray_accelerator_name(self):
         backend_to_ray = {"gpu": "GPU"}
         return backend_to_ray[self.backend]
 
+    def update_worker_config(self, cfg: "GlobalConfig"):
+        """Update the worker config based on the host one"""
+        self.backend = cfg.backend
+        # Random seed used for compilation
+        self.compile_random_seed = cfg.compile_random_seed
+        # Random seed used for runtime
+        self.runtime_random_seed = cfg.runtime_random_seed
+        # XLA server port range
+        self.xla_server_port_start = cfg.xla_server_port_start
+        self.xla_server_port_end = cfg.xla_server_port_end
+        # XLA gpu kernel auto-tuning level
+        self.xla_gpu_autotune_level = cfg.xla_gpu_autotune_level
+        # Whether to use AWS EFA network interface
+        self.use_aws_efa = cfg.use_aws_efa
+        ########## Options of pipeline runtime ##########
+        # Whether to sync before and after the executable for accurate internal
+        # timer
+        self.pipeline_sync_for_timer = cfg.pipeline_sync_for_timer
+        # Whether to use single-byte signal tensor for send/recv.
+        # This is a debug option.
+        self.pipeline_use_signal_send_recv = cfg.pipeline_use_signal_send_recv
+        # Whether to use the scatter-gater/local-all-gather optimization.
+        self.use_local_allgather = cfg.use_local_allgather
+        # Cross mesh resharding mode. Possible choices: {"send_recv",
+        # "broadcast"}
+        self.resharding_mode = cfg.resharding_mode
+        self.nccl_mode = cfg.nccl_mode
+        self.enable_overlapping = cfg.enable_overlapping
+        self.collect_trace = cfg.collect_trace
+
 
 global_config = GlobalConfig()
 
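Note that update_worker_config copies an explicit list of fields rather than the whole object, so only the settings that must be consistent between driver and workers are overwritten. A hypothetical usage, assuming Alpa is installed; this snippet is not part of the commit.

# Hypothetical check of the new method; assumes Alpa is importable.
from alpa.global_env import GlobalConfig

driver_cfg = GlobalConfig()
driver_cfg.runtime_random_seed = 1234
driver_cfg.resharding_mode = "broadcast"

worker_cfg = GlobalConfig()             # fresh defaults, as on a remote host
worker_cfg.update_worker_config(driver_cfg)

assert worker_cfg.runtime_random_seed == 1234
assert worker_cfg.resharding_mode == "broadcast"
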
alpa/pipeline_parallel/compile_executable.py (39 additions, 34 deletions)

@@ -15,9 +15,7 @@
 from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable
 from alpa.pipeline_parallel.runtime_emitter import (
     OverlapFriendlyPipelineInstEmitter, PipelineInstEmitter)
-from alpa.pipeline_parallel.schedules import (GpipeSchedule,
-                                              OverlapFriendlyPipeDreamSchedule,
-                                              PipeDreamFlush, InferenceSchedule)
+from alpa.pipeline_parallel.schedules import create_pipeline_schedule
 from alpa.pipeline_parallel.computation import (
     create_donation_mapping, generate_computations_from_modules,
     generate_sharded_xla_computations,
@@ -38,6 +36,7 @@
 from alpa.shard_parallel.manual_sharding import (ManualShardingOption,
                                                  ParsedManualShardingOption,
                                                  get_flatten_axis_resources,
+                                                 get_intermediate_parsed_spec,
                                                  parsed_spec_to_opsharding)
 from alpa.util import (get_var_mapping, trace_jaxpr_with_micro_batch,
                        OrderedSet, GradFuncTransformContext)
@@ -198,31 +197,14 @@ def compile_pipeshard_executable_internal(
     debug_compilation_time("apply grad")
 
     # Generate pipeline schedule and placement
-    dependency = gen_dependency_with_stages(jax_pipeline_stages,
-                                            sliced_apply_grad_stages)
-    if pipeline_schedule == "gpipe":
-        schedule = GpipeSchedule(dependency=dependency,
-                                 meshes=sliced_virtual_meshes,
-                                 apply_grad_placement=apply_grad_placement,
-                                 num_batch=num_microbatch)
-    elif pipeline_schedule == "1f1b":
-        schedule = PipeDreamFlush(dependency=dependency,
-                                  meshes=sliced_virtual_meshes,
-                                  apply_grad_placement=apply_grad_placement,
-                                  num_batch=num_microbatch)
-    elif pipeline_schedule == "inference":
-        schedule = InferenceSchedule(dependency=dependency,
-                                     meshes=sliced_virtual_meshes,
-                                     apply_grad_placement=apply_grad_placement,
-                                     num_batch=num_microbatch)
-    elif pipeline_schedule == "1f1b_overlap_friendly":
-        schedule = OverlapFriendlyPipeDreamSchedule(
-            dependency=dependency,
-            meshes=sliced_virtual_meshes,
-            apply_grad_placement=apply_grad_placement,
-            num_batch=num_microbatch)
-    else:
-        raise ValueError(f"Invalid schedule: {pipeline_schedule}")
+    dependency, fwd_intermediates = gen_dependency_with_stages(
+        jax_pipeline_stages, num_meshes, sliced_apply_grad_stages)
+    schedule = create_pipeline_schedule(
+        pipeline_schedule,
+        dependency=dependency,
+        meshes=sliced_virtual_meshes,
+        apply_grad_placement=apply_grad_placement,
+        num_batch=num_microbatch)
 
     # Forcibly set the sharding specs of global invars and outvars.
     # FIXME(yonghao): the invar can appear on multiple meshes and thus different
@@ -245,7 +227,7 @@ def compile_pipeshard_executable_internal(
          output_sharding_dicts) = get_manual_input_output_sharding_specs(
              jax_all_stages, manual_stage_option.submesh_logical_shapes,
              parsed_manual_sharding_option, global_invars, global_outvars,
-             schedule.stage_mesh_mapping)
+             schedule.stage_mesh_mapping, fwd_intermediates)
     else:
         input_sharding_dicts = [input_sharding_dict] * num_meshes
         output_sharding_dicts = [output_sharding_dict] * num_meshes
@@ -353,7 +335,7 @@ def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
 
 def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,
                                            global_invars, global_outvars,
-                                           stage_to_mesh):
+                                           stage_to_mesh, fwd_intermediates):
     """
     Split user assigned input and output PartitionSpec into sharding specs for
     each pipeline stage.
@@ -363,19 +345,33 @@ def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,
     var_to_pspec = {}
     handle_invar = False
     handle_outvar = False
+    # Add global input and output's parsed partition spec.
     if ms_option.in_parsed_pspec is not None:
         var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec)))
         handle_invar = True
     if ms_option.out_parsed_pspec is not None:
         var_to_pspec.update(
             dict(zip(global_outvars, ms_option.out_parsed_pspec)))
         handle_outvar = True
+    # Add pipeline intermediate's parsed partition spec.
+    intermediate_to_pspec = {}
+    if ms_option.pipeline_intermediate_axes is not None:
+        for v in fwd_intermediates:
+            # TODO: This is a simple heuristic: we simply replicate 1d tensors.
+            if len(v.aval.shape) <= 1:
+                continue
+            intermediate_to_pspec[v] = get_intermediate_parsed_spec(
+                ms_option.pipeline_intermediate_axes, len(v.aval.shape))
+
     submesh_axis_names = ms_option.submesh_axis_names
     if submesh_axis_names is None:
         submesh_axis_names = [ms_option.mesh_axis_names] * len(mesh_shapes)
 
     def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names):
-        parsed_specs = [var_to_pspec[v] for v in variables]
+        parsed_specs = [
+            (var_to_pspec[v] if v in var_to_pspec else intermediate_to_pspec[v])
+            for v in variables
+        ]
         avals = [v.aval for v in variables]
         var_op_shardings = parsed_spec_to_opsharding(parsed_specs, avals,
                                                      mesh_shape,
@@ -398,8 +394,13 @@ def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names):
         # invars
         if handle_invar:
             invar_in_global = [var for var in stage.invars if var in invar_set]
+            # add intermediate vars
+            intermediate_var = [
+                var for var in stage.invars if var in intermediate_to_pspec
+            ]
+            invars = invar_in_global + intermediate_var
             stage_invar_shardings = get_vars_to_sharding_specs(
-                invar_in_global, mesh_shape, mesh_axis_names)
+                invars, mesh_shape, mesh_axis_names)
         else:
             stage_invar_shardings = {}
         # outvars
@@ -458,13 +459,17 @@ def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,
     compile_intermediate = [None] * num_meshes
     total_flops = 0
     for mesh_idx in range(num_meshes):
-        input_sharding_dict = input_sharding_dicts[mesh_idx]
-        output_sharding_dict = output_sharding_dicts[mesh_idx]
         virtual_mesh = virtual_meshes[mesh_idx]
         logical_mesh = virtual_mesh.get_logical_mesh(
             logical_mesh_shapes[mesh_idx])
         autosharding_option = dataclasses.replace(
             default_as_option, **autosharding_option_dicts[mesh_idx])
+
+        # Predefined shardings. stage_input_sharding should have shardings for
+        # all parameters, while the sharding dict can have only a portion of
+        # all parameters.
+        input_sharding_dict = input_sharding_dicts[mesh_idx]
+        output_sharding_dict = output_sharding_dicts[mesh_idx]
         stage_input_sharding = stage_input_shardings[mesh_idx]
 
         # Setup dummy stages

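The new code path collects the activations passed between forward stages (fwd_intermediates) and assigns each one a partition spec derived from pipeline_intermediate_axes, leaving tensors of rank 1 or less without an intermediate spec so they stay replicated. A standalone sketch of that rank heuristic, using made-up stand-ins rather than real JAX abstract values:

# Standalone sketch of the rank heuristic: only intermediates with rank >= 2
# receive a pipeline_intermediate_axes spec; 1-D tensors and scalars are
# skipped. FakeVar stands in for a pipeline variable and its abstract shape.
from collections import namedtuple

FakeVar = namedtuple("FakeVar", ["name", "shape"])

fwd_intermediates = [
    FakeVar("activation", (32, 1024)),  # rank 2 -> gets a sharding spec
    FakeVar("timestep", (32,)),         # rank 1 -> skipped, stays replicated
]

needs_spec = [v for v in fwd_intermediates if len(v.shape) > 1]
print([v.name for v in needs_spec])     # ['activation']
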
alpa/pipeline_parallel/schedules.py (23 additions, 2 deletions)

@@ -2,7 +2,7 @@
 import itertools
 import logging
 from abc import abstractmethod, ABCMeta
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import numpy as np
 
@@ -15,23 +15,29 @@
 
 def gen_dependency_with_stages(
         compute_stages: List[PipelineComputation],
+        num_mesh: int,
         apply_grad_stages: List[PipelineComputation] = ()):
     """Generate the dependency matrix for a list of pipeline stages."""
     n_stages = len(compute_stages) + len(apply_grad_stages)
     d = np.zeros([n_stages, n_stages], dtype=int)
     var_stage_id = {}
+    fwd_intermediate_vars = OrderedSet()
     for i, stage in enumerate(itertools.chain(compute_stages,
                                               apply_grad_stages)):
         for var in stage.invars:
             if var in var_stage_id:
                 d[i, var_stage_id[var]] = 1
+                if i < num_mesh and var_stage_id[var] != 2 * num_mesh - i - 1:
+                    # not the var from forward to backward. we don't care them.
+                    # not the var on the backward side
+                    fwd_intermediate_vars.add(var)
             else:
                 # Assume the var is from global_invars
                 pass
         for var in stage.outvars:
             var_stage_id[var] = i
 
-    return d
+    return d, fwd_intermediate_vars
 
 
 def gen_linear_pipeline_dependency(num_stage):
@@ -510,3 +516,18 @@ def _generate_schedule(self):
                 scheds[mesh_idx] = (self.last_backward_batch_index, stage_idx)
             schedules.append(scheds)
         return schedules
+
+
+pipeline_schedule: Dict[str, PipelineSchedule] = {}
+pipeline_schedule["gpipe"] = GpipeSchedule
+pipeline_schedule["1f1b"] = PipeDreamFlush
+pipeline_schedule["inference"] = InferenceSchedule
+pipeline_schedule["1f1b_overlap_friendly"] = OverlapFriendlyPipeDreamSchedule
+
+
+def create_pipeline_schedule(name, dependency, meshes, apply_grad_placement,
+                             num_batch):
+    return pipeline_schedule[name](dependency=dependency,
+                                   meshes=meshes,
+                                   apply_grad_placement=apply_grad_placement,
+                                   num_batch=num_batch)

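The if/elif chain in compile_executable.py is replaced by the name-to-class registry plus the create_pipeline_schedule factory above; one consequence is that an unknown schedule name now surfaces as a KeyError from the dict lookup rather than the old explicit ValueError. A standalone sketch of the registry/factory pattern, with stand-in schedule classes rather than Alpa's:

# Standalone sketch of the registry/factory pattern; ToySchedule and its
# subclasses are illustrative stand-ins, not Alpa's schedule classes.
from typing import Dict, Type


class ToySchedule:

    def __init__(self, dependency, meshes, apply_grad_placement, num_batch):
        self.num_batch = num_batch


class ToyGpipe(ToySchedule):
    pass


class Toy1F1B(ToySchedule):
    pass


SCHEDULES: Dict[str, Type[ToySchedule]] = {
    "gpipe": ToyGpipe,
    "1f1b": Toy1F1B,
}


def create_schedule(name, **kwargs):
    # This sketch keeps the explicit error message of the old if/elif chain.
    if name not in SCHEDULES:
        raise ValueError(f"Invalid schedule: {name}")
    return SCHEDULES[name](**kwargs)


sched = create_schedule("1f1b", dependency=None, meshes=[],
                        apply_grad_placement={}, num_batch=8)
assert isinstance(sched, Toy1F1B)
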
alpa/shard_parallel/manual_sharding.py (27 additions, 5 deletions)

@@ -1,13 +1,14 @@
 """User specified manual sharding strategy following pjit's api."""
 import dataclasses
-from typing import Any, Optional, OrderedDict, Tuple, Union
+from typing import Any, Optional, OrderedDict, Sequence, Tuple, Union
 
 from jax._src.lib import xla_client as xc
 from jax._src.tree_util import _replace_nones
 from jax._src.util import safe_zip
 from jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda,
                                    _prepare_axis_resources, get_array_mapping,
-                                   _UNSPECIFIED, ParsedPartitionSpec)
+                                   _UNSPECIFIED, PartitionSpec,
+                                   ParsedPartitionSpec)
 from jax.interpreters import mlir, pxla
 from jax.tree_util import tree_unflatten, tree_flatten, tree_map
 
@@ -22,6 +23,12 @@ class ManualShardingOption:
     # According to pjit, None means replicated.
     in_axis_resources: Any = _UNSPECIFIED
     out_axis_resources: Any = _UNSPECIFIED
+    # To enable data parallel for multiple pipeline stages, where the input
+    # activation is not a global invar. Currently defined by (dim_name, dim_idx)
+    # TODO: a better design to allow only applying this rule to a subset of
+    # intermediate, because some pipeline communicated tensors do not have a
+    # batch dim. e.g. the time vector in diffusion generated at the first stage.
+    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None
 
 
 @dataclasses.dataclass
@@ -32,6 +39,7 @@ class ParsedManualShardingOption:
     # Parsed and flatten status
     in_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None
     out_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None
+    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None
 
 
 def _parsed_pspec_to_hlo_sharding(
@@ -121,9 +129,9 @@ def get_flatten_axis_resources(sharding_option: ManualShardingOption, in_tree,
     else:
         out_axis_flat = _prepare_axis_and_flatten(
             sharding_option.out_axis_resources, out_tree, "out_axis_resources")
-    return ParsedManualShardingOption(sharding_option.mesh_axis_names,
-                                      sharding_option.submesh_axis_names,
-                                      in_axis_flat, out_axis_flat)
+    return ParsedManualShardingOption(
+        sharding_option.mesh_axis_names, sharding_option.submesh_axis_names,
+        in_axis_flat, out_axis_flat, sharding_option.pipeline_intermediate_axes)
 
 
 def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names):
@@ -156,3 +164,17 @@ def get_manual_sharding_spec(
         parsed_resources.out_parsed_pspec, out_avals, mesh_shape,
         mesh_axis_names)
     return in_op_shardings, out_op_shardings
+
+
+def get_intermediate_parsed_spec(intermediate_dims,
+                                 dim_len,
+                                 allow_unconstrained_dims=False):
+    axes = [None] * dim_len
+    for (name, dim) in intermediate_dims:
+        axes[dim] = name
+    pspec = PartitionSpec(*axes)
+    parsed_pspec = ParsedPartitionSpec.from_user_input(
+        pspec,
+        "intermediate specifications",
+        allow_unconstrained_dims=allow_unconstrained_dims)
+    return parsed_pspec

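The user-facing knob is the new ManualShardingOption.pipeline_intermediate_axes field, a sequence of (mesh_axis_name, tensor_dim) pairs; get_intermediate_parsed_spec turns it into a ParsedPartitionSpec of the requested rank. A hypothetical usage, assuming Alpa and a compatible JAX version are installed; this snippet is not part of the commit.

# Hypothetical usage of the new helper; assumes Alpa is importable and its
# pinned JAX version provides ParsedPartitionSpec.from_user_input.
from alpa.shard_parallel.manual_sharding import get_intermediate_parsed_spec

# Shard dim 0 of a rank-3 pipeline intermediate along the "data" mesh axis;
# this corresponds to PartitionSpec('data', None, None).
parsed = get_intermediate_parsed_spec([("data", 0)], dim_len=3)
print(parsed)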