diff --git a/alpa/pipeline_parallel/cross_mesh_resharding.py b/alpa/pipeline_parallel/cross_mesh_resharding.py
index 388ba8847..19c6c31df 100644
--- a/alpa/pipeline_parallel/cross_mesh_resharding.py
+++ b/alpa/pipeline_parallel/cross_mesh_resharding.py
@@ -5,8 +5,10 @@
 import math
 import random
 import time
-from typing import List, Any
+from typing import Dict, List, Any, Sequence
 
+from alpa.pipeline_parallel.schedules import PipelineSchedule
+from jax.core import Var
 from jax.interpreters import pxla
 import numpy as np
 import ray
@@ -682,7 +684,8 @@ class ReshardingTaskSpec:
             VirtualDistributedArray.
     """
 
-    def __init__(self, src_array, dst_array, final_dst_spec):
+    def __init__(self, src_array: VirtualDistributedArray,
+                 dst_array: VirtualDistributedArray, final_dst_spec):
         self.src = src_array
         self.dst = dst_array
         self._dst_tile_to_src_tiles_map = None
@@ -949,7 +952,8 @@ class CrossMeshCommunicator:
         schedule (Any): the pipelining schedule for these stages.
     """
 
-    def __init__(self, sharded_stages, schedule):
+    def __init__(self, sharded_stages: Sequence[XlaShardedPipelineComputation],
+                 schedule: PipelineSchedule):
         if not isinstance(sharded_stages, list):
             raise RuntimeError("Require a list of stages.")
         for s in sharded_stages:
@@ -1091,6 +1095,9 @@ def _create_resharding_specs(self):
             [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
         ]
 
+        # var -> index of the last stage that took it as an input. A later
+        # consumer reshards from that stage's mesh rather than the producer's.
+        last_seen: Dict[Var, int] = {}
         # find stages that will communicate
         pairs = np.argwhere(deps > 0)
         for i in range(pairs.shape[0]):
@@ -1116,29 +1123,48 @@ def _create_resharding_specs(self):
             out_sharding_specs = src_stage.output_sharding_specs
             in_sharding_specs = dst_stage.input_sharding_specs
 
-            # Make a ReshardSpec for each VirtualDistributedArray
+            # Make a ReshardingTaskSpec for each VirtualDistributedArray
             for var, out_var_index, in_var_index in zip(resharding_vars,
                                                         out_var_indices,
                                                         in_var_indices):
-                src_sharding_spec = out_sharding_specs[out_var_index]
-                dst_sharding_spec = in_sharding_specs[in_var_index]
+                if var in last_seen:
+                    last_seen_stage_index = last_seen[var]
+                    last_seen[var] = dst_stage_index
+
+                    last_seen_stage = stages[last_seen_stage_index]
+                    last_seen_var_index = last_seen_stage.invars.index(var)
+                    last_seen_sharding_spec = last_seen_stage.input_sharding_specs[
+                        last_seen_var_index]
+
+                    last_seen_mesh_index = stage_placements[
+                        last_seen_stage_index]
+                    last_seen_mesh = meshes[last_seen_mesh_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=last_seen_mesh,
+                        aval=var.aval,
+                        sharding_spec=last_seen_sharding_spec)
+                    final_src_mesh_index = last_seen_mesh_index
+                else:
+                    last_seen[var] = dst_stage_index
+                    src_sharding_spec = out_sharding_specs[out_var_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=src_mesh,
+                        aval=var.aval,
+                        sharding_spec=src_sharding_spec)
+                    final_src_mesh_index = src_mesh_index
 
+                dst_sharding_spec = in_sharding_specs[in_var_index]
                 final_dst_spec = dst_sharding_spec
                 if global_config.resharding_mode == "send_recv":
                     dst_sharding_spec = self._rewrite_allgather_spec(
                         dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)
-
-                src_array = VirtualDistributedArray(
-                    device_mesh=src_mesh,
-                    aval=var.aval,
-                    sharding_spec=src_sharding_spec)
                 dst_array = VirtualDistributedArray(
                     device_mesh=dst_mesh,
                     aval=var.aval,
                     sharding_spec=dst_sharding_spec)
-                task_spec = ReshardingTaskSpec(src_array, dst_array,
+                task_spec = ReshardingTaskSpec(final_src_array, dst_array,
                                                final_dst_spec)
-                self.resharding_specs[src_mesh_index][dst_mesh_index][
+                self.resharding_specs[final_src_mesh_index][dst_mesh_index][
                     var] = task_spec
 
     def task_spec_iter(self):
@@ -1425,7 +1451,8 @@ def _generate_broadcast_resharding_strategy_by_loads(
         return strategy
 
     @staticmethod
-    def _args_between(src_stage, dst_stage):
+    def _args_between(src_stage: XlaShardedPipelineComputation,
+                      dst_stage: XlaShardedPipelineComputation):
         """Find the variable exchanged between stages."""
         resharding_vars = []
         src_indices = []