15
15
from alpa .pipeline_parallel .pipeshard_executable import PipeshardDriverExecutable
16
16
from alpa .pipeline_parallel .runtime_emitter import (
17
17
OverlapFriendlyPipelineInstEmitter , PipelineInstEmitter )
18
- from alpa .pipeline_parallel .schedules import (GpipeSchedule ,
19
- OverlapFriendlyPipeDreamSchedule ,
20
- PipeDreamFlush , InferenceSchedule )
18
+ from alpa .pipeline_parallel .schedules import create_pipeline_schedule
21
19
from alpa .pipeline_parallel .computation import (
22
20
create_donation_mapping , generate_computations_from_modules ,
23
21
generate_sharded_xla_computations ,
38
36
from alpa .shard_parallel .manual_sharding import (ManualShardingOption ,
39
37
ParsedManualShardingOption ,
40
38
get_flatten_axis_resources ,
39
+ get_intermediate_parsed_spec ,
41
40
parsed_spec_to_opsharding )
42
41
from alpa .util import (get_var_mapping , trace_jaxpr_with_micro_batch ,
43
42
OrderedSet , GradFuncTransformContext )
@@ -198,31 +197,14 @@ def compile_pipeshard_executable_internal(
198
197
debug_compilation_time ("apply grad" )
199
198
200
199
# Generate pipeline schedule and placement
201
- dependency = gen_dependency_with_stages (jax_pipeline_stages ,
202
- sliced_apply_grad_stages )
203
- if pipeline_schedule == "gpipe" :
204
- schedule = GpipeSchedule (dependency = dependency ,
205
- meshes = sliced_virtual_meshes ,
206
- apply_grad_placement = apply_grad_placement ,
207
- num_batch = num_microbatch )
208
- elif pipeline_schedule == "1f1b" :
209
- schedule = PipeDreamFlush (dependency = dependency ,
210
- meshes = sliced_virtual_meshes ,
211
- apply_grad_placement = apply_grad_placement ,
212
- num_batch = num_microbatch )
213
- elif pipeline_schedule == "inference" :
214
- schedule = InferenceSchedule (dependency = dependency ,
215
- meshes = sliced_virtual_meshes ,
216
- apply_grad_placement = apply_grad_placement ,
217
- num_batch = num_microbatch )
218
- elif pipeline_schedule == "1f1b_overlap_friendly" :
219
- schedule = OverlapFriendlyPipeDreamSchedule (
220
- dependency = dependency ,
221
- meshes = sliced_virtual_meshes ,
222
- apply_grad_placement = apply_grad_placement ,
223
- num_batch = num_microbatch )
224
- else :
225
- raise ValueError (f"Invalid schedule: { pipeline_schedule } " )
200
+ dependency , fwd_intermediates = gen_dependency_with_stages (
201
+ jax_pipeline_stages , num_meshes , sliced_apply_grad_stages )
202
+ schedule = create_pipeline_schedule (
203
+ pipeline_schedule ,
204
+ dependency = dependency ,
205
+ meshes = sliced_virtual_meshes ,
206
+ apply_grad_placement = apply_grad_placement ,
207
+ num_batch = num_microbatch )
226
208
227
209
# Forcibly set the sharding specs of global invars and outvars.
228
210
# FIXME(yonghao): the invar can appear on multiple meshes and thus different
@@ -245,7 +227,7 @@ def compile_pipeshard_executable_internal(
245
227
output_sharding_dicts ) = get_manual_input_output_sharding_specs (
246
228
jax_all_stages , manual_stage_option .submesh_logical_shapes ,
247
229
parsed_manual_sharding_option , global_invars , global_outvars ,
248
- schedule .stage_mesh_mapping )
230
+ schedule .stage_mesh_mapping , fwd_intermediates )
249
231
else :
250
232
input_sharding_dicts = [input_sharding_dict ] * num_meshes
251
233
output_sharding_dicts = [output_sharding_dict ] * num_meshes
@@ -353,7 +335,7 @@ def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
353
335
354
336
def get_manual_input_output_sharding_specs (stages , mesh_shapes , ms_option ,
355
337
global_invars , global_outvars ,
356
- stage_to_mesh ):
338
+ stage_to_mesh , fwd_intermediates ):
357
339
"""
358
340
Split user assigned input and output PartitionSpec into sharding specs for
359
341
each pipeline stage.
@@ -363,19 +345,33 @@ def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,
363
345
var_to_pspec = {}
364
346
handle_invar = False
365
347
handle_outvar = False
348
+ # Add global input and output's parsed partition spec.
366
349
if ms_option .in_parsed_pspec is not None :
367
350
var_to_pspec .update (dict (zip (global_invars , ms_option .in_parsed_pspec )))
368
351
handle_invar = True
369
352
if ms_option .out_parsed_pspec is not None :
370
353
var_to_pspec .update (
371
354
dict (zip (global_outvars , ms_option .out_parsed_pspec )))
372
355
handle_outvar = True
356
+ # Add pipeline intermediate's parsed partition spec.
357
+ intermediate_to_pspec = {}
358
+ if ms_option .pipeline_intermediate_axes is not None :
359
+ for v in fwd_intermediates :
360
+ # TODO: This is a simple heuristic: we simply replicate tensors with at most one dimension (0-d and 1-d).
361
+ if len (v .aval .shape ) <= 1 :
362
+ continue
363
+ intermediate_to_pspec [v ] = get_intermediate_parsed_spec (
364
+ ms_option .pipeline_intermediate_axes , len (v .aval .shape ))
365
+
373
366
submesh_axis_names = ms_option .submesh_axis_names
374
367
if submesh_axis_names is None :
375
368
submesh_axis_names = [ms_option .mesh_axis_names ] * len (mesh_shapes )
376
369
377
370
def get_vars_to_sharding_specs (variables , mesh_shape , mesh_axis_names ):
378
- parsed_specs = [var_to_pspec [v ] for v in variables ]
371
+ parsed_specs = [
372
+ (var_to_pspec [v ] if v in var_to_pspec else intermediate_to_pspec [v ])
373
+ for v in variables
374
+ ]
379
375
avals = [v .aval for v in variables ]
380
376
var_op_shardings = parsed_spec_to_opsharding (parsed_specs , avals ,
381
377
mesh_shape ,
@@ -398,8 +394,13 @@ def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names):
398
394
# invars
399
395
if handle_invar :
400
396
invar_in_global = [var for var in stage .invars if var in invar_set ]
397
+ # add intermediate vars
398
+ intermediate_var = [
399
+ var for var in stage .invars if var in intermediate_to_pspec
400
+ ]
401
+ invars = invar_in_global + intermediate_var
401
402
stage_invar_shardings = get_vars_to_sharding_specs (
402
- invar_in_global , mesh_shape , mesh_axis_names )
403
+ invars , mesh_shape , mesh_axis_names )
403
404
else :
404
405
stage_invar_shardings = {}
405
406
# outvars
@@ -458,13 +459,17 @@ def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,
458
459
compile_intermediate = [None ] * num_meshes
459
460
total_flops = 0
460
461
for mesh_idx in range (num_meshes ):
461
- input_sharding_dict = input_sharding_dicts [mesh_idx ]
462
- output_sharding_dict = output_sharding_dicts [mesh_idx ]
463
462
virtual_mesh = virtual_meshes [mesh_idx ]
464
463
logical_mesh = virtual_mesh .get_logical_mesh (
465
464
logical_mesh_shapes [mesh_idx ])
466
465
autosharding_option = dataclasses .replace (
467
466
default_as_option , ** autosharding_option_dicts [mesh_idx ])
467
+
468
+ # Predefined shardings. stage_input_sharding should have shardings for
469
+ # all parameters, while the sharding dict can have only a portion of
470
+ # all parameters.
471
+ input_sharding_dict = input_sharding_dicts [mesh_idx ]
472
+ output_sharding_dict = output_sharding_dicts [mesh_idx ]
468
473
stage_input_sharding = stage_input_shardings [mesh_idx ]
469
474
470
475
# Setup dummy stages
0 commit comments