From 460751b6574359b129a00fb7d39934e932fe6c0f Mon Sep 17 00:00:00 2001
From: Mchoi-git
Date: Sun, 21 Jan 2024 14:41:30 -0500
Subject: [PATCH 1/4] Added strategies and cached state management

---
 flex_model/distributed/cached_state.py | 84 ++++++++++++++++++++++++++
 flex_model/distributed/strategies.py   | 25 ++++++++
 2 files changed, 109 insertions(+)
 create mode 100644 flex_model/distributed/cached_state.py

diff --git a/flex_model/distributed/cached_state.py b/flex_model/distributed/cached_state.py
new file mode 100644
index 0000000..de90492
--- /dev/null
+++ b/flex_model/distributed/cached_state.py
@@ -0,0 +1,84 @@
+from torch import Tensor
+import torch.nn as nn
+
+from flex_model.distributed.distributed_state import _ParallelStateAPI
+from flex_model.distributed.stratgies import (
+    SaveCtxStrategy,
+    TrainableModulesStrategy,
+)
+
+
+def pipeline_sync(obj_to_sync, fmps: _ParallelStateAPI):
+    raise NotImplementedError
+
+
+class SaveContext:
+    def __init__(
+        self,
+        fmps: _ParallelStateAPI,
+        strategy: SaveCtxStrategy = SaveCtxStrategy.REPLICATE_ALL,
+    ):
+        self.strategy = strategy
+        self.fmps = fmps
+
+    def save(self, *tensors: Tensor):
+        for t in tensors:
+            assert isinstance(t, Tensor), (
+                "The `save` function should only be used on tensor instances. ",
+                "Non-tensor data can be saved using `save_ctx.data = item.",
+            )
+
+        # Don't cache tensors depending on strategy.
+        # dp_rank = self.fmps.get_data_parallel_rank()
+        # pp_rank = self.fmps.get_pipeline_parallel_rank()
+        # tp_rank = self.fmps.get_tensor_parallel_rank()
+
+        # Check if this rank should actually cache tensors.
+        do_cache_tensors = True
+        if self.strategy != SaveCtxStrategy.REPLICATE_ALL:
+            if self.strategy == SaveCtxStrategy.REPLICATE_PP:
+                raise NotImplementedError(
+                    "Passing save context along pipeline ranks is not "
+                    "implemented."
+                )
+
+        # Cache tensors if necessary.
+        if do_cache_tensors is True:
+            self.cached_tensors = tensors
+
+    def get_cached_tensors(self):
+        return self.cached_tensors
+
+    def sync(self):
+        # Check if we need to sync
+        do_sync = False
+        if self.strategy == SaveCtxStrategy.REPLICATE_ALL:
+            do_sync = True
+
+        # Sync if necessary.
+        if do_sync is True:
+            pipeline_sync(self._modules, self.fmps)
+        do_sync = False
+
+
+class TrainableModules(nn.ModuleDict):
+    def __init__(
+        self,
+        fmps: _ParallelStateAPI,
+        *args,
+        strategy: TrainableModulesStrategy = TrainableModulesStrategy.REPLICATE_ALL,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.strategy = strategy
+        self.fmps = fmps
+
+    def sync(self):
+        # Check if we need to sync
+        do_sync = False
+        if self.strategy == TrainableModulesStrategy.REPLICATE_ALL:
+            do_sync = True
+
+        # Sync if necessary.
+        if do_sync is True:
+            pipeline_sync(self._modules, self.fmps)
diff --git a/flex_model/distributed/strategies.py b/flex_model/distributed/strategies.py
index c232789..7606eca 100644
--- a/flex_model/distributed/strategies.py
+++ b/flex_model/distributed/strategies.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional
 
 from torch import Tensor
@@ -279,3 +280,27 @@ def is_valid(fn) -> bool:
             return cls(user_func)
         else:
             raise Exception("Provided editing function is not valid")
+
+
+class SaveCtxStrategy(Enum):
+    """
+    Defines a strategy which handles allocating `save_ctx` data. Default
+    stratgy is to simply replicate all `save_ctx` across all workers.
+    """
+
+    REPLICATE_ALL = 1
+    REPLICATE_DP = 2
+    REPLICATE_PP = 3
+    REPLICATE_TP = 4
+
+
+class TrainableModulesStrategy(Enum):
+    """
+    Defines a strategy which handles allocating `trainable_modules`.
+    Default stratgy is to simply replicate across all workers.
+    """
+
+    REPLICATE_ALL = 1
+    REPLICATE_DP = 2
+    REPLICATE_PP = 3
+    REPLICATE_TP = 4

From a7ce5a0809bdf51b580c69df9fe2ba5ca2ffc3d3 Mon Sep 17 00:00:00 2001
From: Mchoi-git
Date: Mon, 22 Jan 2024 07:10:18 -0500
Subject: [PATCH 2/4] added cache sync

---
 flex_model/distributed/cached_state.py | 326 +++++++++++++++++++++++++
 1 file changed, 326 insertions(+)

diff --git a/flex_model/distributed/cached_state.py b/flex_model/distributed/cached_state.py
index de90492..2f15fb6 100644
--- a/flex_model/distributed/cached_state.py
+++ b/flex_model/distributed/cached_state.py
@@ -1,12 +1,338 @@
+from typing import Dict, Union, Optional, Tuple, List
+import logging
+
+import torch
 from torch import Tensor
 import torch.nn as nn
 
 from flex_model.distributed.distributed_state import _ParallelStateAPI
+from flex_model.distributed.mappings import _log_shape
 from flex_model.distributed.stratgies import (
     SaveCtxStrategy,
     TrainableModulesStrategy,
 )
 
+logger = logging.getLogger(__name__)
+
+
+def _group_by_dtype(
+    tensor_dict: Dict[str, Tensor]
+) -> Dict[torch.dtype, Dict[str, Tensor]]:
+    dtypes = [torch.float32, torch.float16, torch.bfloat16]
+    dtype_groups: Dict[torch.dtype, Dict[str, Tensor]] = {
+        dtype: {} for dtype in dtypes
+    }
+
+    for name, tensor in tensor_dict.items():
+        assert tensor.dtype in dtype_groups, (
+            f"Tensor with dtype: {tensor.dtype} is not supported for "
+            f"gathering across PP ranks."
+        )
+        dtype_groups[tensor.dtype][name] = tensor
+
+    return dtype_groups
+
+
+# Tensor buffer metadata type.
+_TBUF_META = Dict[
+    str,
+    Union[
+        int,
+        torch.dtype,
+        Dict[str, Tuple[int, int]],
+        Dict[str, torch.Size],
+    ],
+]
+
+
+def _make_flat_buffer(
+    fmps: _ParallelStateAPI,
+    tensor_dict: Dict[str, Tensor],
+) -> Tuple[Optional[Tensor], Optional[_TBUF_META]]:
+    tensors = []
+    name_to_index_map = {}
+    name_to_shape_map = {}
+    curr_idx = 0
+    for name, tensor in tensor_dict.items():
+        shape = tensor.shape
+        numel = tensor.numel()
+        tensors.append(tensor.flatten())
+
+        name_to_index_map[name] = (curr_idx, curr_idx + numel)
+        name_to_shape_map[name] = shape
+
+        curr_idx += numel
+
+    if len(tensors) == 0:
+        return None, None
+
+    tensor_buffer = torch.cat(tensors)
+
+    meta: _TBUF_META = {
+        "buffer_rank": fmps.get_pipeline_parallel_rank(),
+        "buffer_size": tensor_buffer.numel(),
+        "buffer_dtype": tensor_buffer.dtype,
+        "name_to_index_map": name_to_index_map,
+        "name_to_shape_map": name_to_shape_map,
+    }
+
+    return tensor_buffer, meta
+
+
+def _broadcast_pipeline_parallel(
+    fmps: _ParallelStateAPI,
+    tbuf_groups: Dict[torch.dtype, Optional[Tensor]],
+    all_metadata_groups: List[Optional[Dict[torch.dtype, _TBUF_META]]],
+) -> Dict[str, Tensor]:
+    rank = fmps.get_pipeline_parallel_rank()
+
+    # Setup collections for communication
+    def _empty_groups() -> Dict[torch.dtype, List[Union[Tensor, int]]]:
+        return {dtype: [] for dtype in tbuf_groups.keys()}
+
+    recv_tbuf_groups = _empty_groups()
+    recv_rank_groups = _empty_groups()
+    send_tbuf_groups = _empty_groups()
+    send_rank_groups = _empty_groups()
+
+    # Construct recv tensors and src ranks.
+    for metadata_groups in all_metadata_groups:
+        # Skip if the rank has no tbufs to recv for any dtype.
+        if metadata_groups is None:
+            continue
+
+        for dtype, metadata in metadata_groups.items():
+            # Skip if there's no tbuf to recv for the dtype or the source
+            # rank is 0 (rank0 never sends).
+            if metadata is None:
+                continue
+
+            buffer_rank = metadata["buffer_rank"]
+            buffer_size = metadata["buffer_size"]
+            buffer_dtype = metadata["buffer_dtype"]
+            assert (
+                buffer_dtype == dtype
+            ), f"Dtype mismatch: {buffer_dtype} and {dtype}"
+
+            tbuf = torch.empty((buffer_size,), dtype=buffer_dtype)
+            src_rank = buffer_rank
+            recv_tbuf_groups[dtype].append(tbuf)
+            recv_rank_groups[dtype].append(src_rank)
+
+            logger.debug(
+                f"Rank{rank}: Constructed recv - "
+                f"({tbuf.numel()}) [{src_rank}] -> [{fmps.get_rank()}]"
+            )
+
+    # Construct send tensors and dst ranks.
+    for dtype, tbuf in tbuf_groups.items():
+        # Skip if there's no tbuf to send for the dtype.
+        if tbuf is None:
+            continue
+
+        # Send dst always rank0.
+        for r in fmps.get_world_size():
+            send_tbuf_groups[dtype].append(tbuf)
+            send_rank_groups[dtype].append(r)
+
+            logger.debug(
+                f"Rank{rank}: Constructed send - "
+                f"({tbuf.numel()}) [{rank}] -> [{r}]"
+            )
+
+    def _set_device(_buffer_list, device):
+        return [_buffer.to(device) for _buffer in _buffer_list]
+
+    # Batched communication across all dtype groups.
+    all_recv_tbufs = []
+    all_recv_ranks = []
+    all_send_tbufs = []
+    all_send_ranks = []
+    for dtype in tbuf_groups.keys():
+        recv_tbufs = _set_device(
+            recv_tbuf_groups[dtype], device=torch.cuda.current_device()
+        )
+        send_tbufs = _set_device(
+            send_tbuf_groups[dtype], device=torch.cuda.current_device()
+        )
+        all_recv_tbufs.extend(recv_tbufs)
+        all_recv_ranks.extend(recv_rank_groups[dtype])
+        all_send_tbufs.extend(send_tbufs)
+        all_send_ranks.extend(send_rank_groups[dtype])
+
+    batch_isend_irecv_pipeline_parallel(
+        fmps,
+        all_recv_tbufs,
+        all_recv_ranks,
+        all_send_tbufs,
+        all_send_ranks,
+    )
+    all_recv_tbufs = _set_device(all_recv_tbufs, device="cpu")
+    all_send_tbufs = _set_device(all_send_tbufs, device="cpu")
+
+    # Unshard each tbuf into individual tensors.
+    output_tensor_dict: Dict[str, Tensor] = {}
+    if rank == 0:
+
+        def _reshard_tbuf(meta, tbuf):
+            for name, (start, end) in meta["name_to_index_map"].items():
+                shape = meta["name_to_shape_map"][name]
+                output_tensor_dict[name] = tbuf[start:end].reshape(shape)
+
+        # Add rank0 local tbufs.
+        for dtype, tbuf in tbuf_groups.items():
+            meta = all_metadata_groups[0][dtype]
+            if meta is not None:
+                _reshard_tbuf(meta, tbuf)
+
+        # Add gathered tbufs.
+        for recv_tbuf, recv_r in zip(all_recv_tbufs, all_recv_ranks):
+            dtype = recv_tbuf.dtype
+            meta = all_metadata_groups[recv_r][dtype]
+
+            buf_rank = meta["buffer_rank"]
+            buf_dtype = meta["buffer_dtype"]
+            assert (
+                buf_dtype == dtype
+            ), f"Dtype mismatch: {buf_dtype} and {dtype}"
+            assert buf_rank == recv_r, f"Rank mismatch: {buf_rank} and {recv_r}"
+
+            _reshard_tbuf(meta, recv_tbuf)
+
+    return output_tensor_dict
+
+
+def batch_isend_irecv_pipeline_parallel(
+    fmps: _ParallelStateAPI,
+    recv_tensors: List[Tensor],
+    recv_from_ranks: List[int],
+    send_tensors: List[Tensor],
+    send_to_ranks: List[int],
+) -> None:
+    """Run batched peer-to-peer communications.
+
+    :param List[Tensor] recv_tensors: Tensors to receive.
+    :param List[int] recv_from_ranks: Ranks to receive from.
+    :param List[Tensor] send_tensors: Tensors to send.
+    :param List[int] send_to_ranks: Ranks to send to.
+    """
+    rank = fmps.get_pipeline_parallel_rank()
+    group = fmps.get_pipeline_parallel_group()
+
+    assert len(recv_tensors) == len(recv_from_ranks), (
+        f"Mistmatch in recv tensors({len(recv_tensors)}) and "
+        f"recv ranks({len(recv_from_ranks)})"
+    )
+    assert len(send_tensors) == len(send_to_ranks), (
+        f"Mistmatch in send tensors({len(send_tensors)}) and "
+        f"send ranks({len(send_to_ranks)})"
+    )
+
+    p2p_ops = []
+    for recv_t, recv_r in zip(recv_tensors, recv_from_ranks):
+        op = torch.distributed.P2POp(
+            torch.distributed.irecv,
+            recv_t,
+            peer=recv_r,
+            group=group,
+        )
+        p2p_ops.append(op)
+
+        logger.debug(f"Rank{rank}: P2POp (irecv) [{rank}] <- [{recv_r}]")
+
+    for send_t, send_r in zip(send_tensors, send_to_ranks):
+        op = torch.distributed.P2POp(
+            torch.distributed.isend,
+            send_t,
+            peer=send_r,
+            group=group,
+        )
+        p2p_ops.append(op)
+
+        logger.debug(f"Rank{rank}: P2POp (isend) [{rank}] -> [{send_r}]")
+
+    if len(p2p_ops) == 0:
+        return
+
+    logger.debug(f"Rank{rank}: Launching P2POps")
+
+    reqs = torch.distributed.batch_isend_irecv(p2p_ops)
+    for req in reqs:
+        req.wait()
+
+    def _gen_debug_msg(t_list):
+        return ", ".join([f"({t.numel()}, {t.dtype})" for t in t_list])
+
+    logger.debug(
+        f"Rank{rank}: Received buffers - [{_gen_debug_msg(recv_tensors)}]"
+    )
+    logger.debug(f"Rank{rank}: Sent buffers - [{_gen_debug_msg(send_tensors)}]")
+
+    # TODO: Remove after verification that no race cond. occurs.
+    torch.cuda.synchronize()
+
+
+def gather_pipeline_parallel_tensor_dicts(
+    fmps: _ParallelStateAPI,
+    tensor_dict: Dict[str, Tensor],
+) -> Dict[str, Tensor]:
+    """Gather groups of tensors from ranks of the pipeline group to pipeline rank0.
+
+    Note: Assumes input tensors are on CPU and placed output tensors on CPU.
+        - This behaviour is subject to change depending on various optimizations.
+
+    :param _ParallelStateAPI fmps: FlexModel parallel state handle.
+    :param tensor_dict: Some python object that can be pickled. May contain tensors.
+    :type tensor_dict Dict[str, Tensor]:
+
+    :returns: A collection of the objects sent from all pipeline paralel group ranks.
+    :rtype: Dict[str, Tensor]
+    """
+    in_shapes = []
+    for tensor in tensor_dict.values():
+        in_shapes.append(tensor.shape)
+
+    world_size = fmps.get_pipeline_parallel_world_size()
+    rank = fmps.get_pipeline_parallel_rank()
+    group = fmps.get_pipeline_parallel_group()
+
+    tensor_dict_groups = _group_by_dtype(tensor_dict)
+
+    # Convert tensor dicts into flattened buffers with metadata.
+    tbuf_groups = {}
+    metadata_groups = {}
+    for dtype, tensor_dict in tensor_dict_groups.items():
+        tbuf, meta = _make_flat_buffer(fmps, tensor_dict)
+
+        tbuf_groups[dtype] = tbuf
+        metadata_groups[dtype] = meta
+
+    # Gather metadata on rank 0 to setup recv tensors.
+    all_metadata_groups: List[Optional[Dict[torch.dtype, _TBUF_META]]] = [
+        None for _ in range(world_size)
+    ]
+    torch.distributed.gather_object(
+        metadata_groups,
+        all_metadata_groups if rank == 0 else None,
+        dst=0,
+        group=group,
+    )
+
+    # Communicate.
+    output_tensor_dict = _broadcast_pipeline_parallel(
+        fmps, tbuf_groups, all_metadata_groups
+    )
+
+    for in_shape, out_tensor in zip(in_shapes, output_tensor_dict.values()):
+        _log_shape(
+            rank,
+            "gather_pipeline_parallel_tensor_dicts",
+            in_shape,
+            out_tensor.shape,
+        )
+
+    return output_tensor_dict
+
+
 
 def pipeline_sync(obj_to_sync, fmps: _ParallelStateAPI):
     raise NotImplementedError

From 7438f49414cc43fd4b1a0b946dedd9d44f5cc28d Mon Sep 17 00:00:00 2001
From: Mchoi-git
Date: Thu, 1 Feb 2024 10:00:43 -0500
Subject: [PATCH 3/4] generalized pp gather

---
 flex_model/distributed/cached_state.py | 86 +++++---------------------
 flex_model/distributed/strategies.py   | 25 --------
 2 files changed, 15 insertions(+), 96 deletions(-)

diff --git a/flex_model/distributed/cached_state.py b/flex_model/distributed/cached_state.py
index 2f15fb6..d88f53c 100644
--- a/flex_model/distributed/cached_state.py
+++ b/flex_model/distributed/cached_state.py
@@ -3,27 +3,24 @@
 
 import torch
 from torch import Tensor
-import torch.nn as nn
 
 from flex_model.distributed.distributed_state import _ParallelStateAPI
 from flex_model.distributed.mappings import _log_shape
 from flex_model.distributed.stratgies import (
     SaveCtxStrategy,
-    TrainableModulesStrategy,
 )
 
 logger = logging.getLogger(__name__)
 
 
 def _group_by_dtype(
-    tensor_dict: Dict[str, Tensor]
+    tensors: Dict[str, Tensor]
 ) -> Dict[torch.dtype, Dict[str, Tensor]]:
     dtypes = [torch.float32, torch.float16, torch.bfloat16]
-    dtype_groups: Dict[torch.dtype, Dict[str, Tensor]] = {
-        dtype: {} for dtype in dtypes
-    }
 
-    for name, tensor in tensor_dict.items():
+    dtype_groups = {dtype: {} for dtype in dtypes}
+
+    for name, tensor in tensors.items():
         assert tensor.dtype in dtype_groups, (
             f"Tensor with dtype: {tensor.dtype} is not supported for "
             f"gathering across PP ranks."
@@ -307,14 +304,13 @@ def gather_pipeline_parallel_tensor_dicts(
         tbuf_groups[dtype] = tbuf
         metadata_groups[dtype] = meta
 
-    # Gather metadata on rank 0 to setup recv tensors.
+    # Gather metadata on all ranks to setup recv tensors.
     all_metadata_groups: List[Optional[Dict[torch.dtype, _TBUF_META]]] = [
         None for _ in range(world_size)
     ]
-    torch.distributed.gather_object(
+    torch.distributed.all_gather_object(
+        all_metadata_groups,
         metadata_groups,
-        all_metadata_groups if rank == 0 else None,
-        dst=0,
         group=group,
     )
 
@@ -334,8 +330,8 @@ def gather_pipeline_parallel_tensor_dicts(
     return output_tensor_dict
 
 
-def pipeline_sync(obj_to_sync, fmps: _ParallelStateAPI):
-    raise NotImplementedError
+def _pipeline_sync(tensor_dict: Dict[str, Tensor], fmps: _ParallelStateAPI):
+    gather_pipeline_parallel_tensor_dicts(tensor_dict, fmps)
 
 
 class SaveContext:
@@ -346,65 +342,13 @@ def __init__(
     ):
         self.strategy = strategy
         self.fmps = fmps
+        self.cached_tensors = {}
 
-    def save(self, *tensors: Tensor):
-        for t in tensors:
-            assert isinstance(t, Tensor), (
-                "The `save` function should only be used on tensor instances. ",
-                "Non-tensor data can be saved using `save_ctx.data = item.",
-            )
-
-        # Don't cache tensors depending on strategy.
-        # dp_rank = self.fmps.get_data_parallel_rank()
-        # pp_rank = self.fmps.get_pipeline_parallel_rank()
-        # tp_rank = self.fmps.get_tensor_parallel_rank()
+    def save(self, name: str, tensor: Tensor):
+        self.cached_tensors[name] = tensor
 
-        # Check if this rank should actually cache tensors.
-        do_cache_tensors = True
-        if self.strategy != SaveCtxStrategy.REPLICATE_ALL:
-            if self.strategy == SaveCtxStrategy.REPLICATE_PP:
-                raise NotImplementedError(
-                    "Passing save context along pipeline ranks is not "
-                    "implemented."
-                )
-
-        # Cache tensors if necessary.
-        if do_cache_tensors is True:
-            self.cached_tensors = tensors
-
-    def get_cached_tensors(self):
-        return self.cached_tensors
+    def get_tensor(self, name: str) -> Tensor:
+        return self.cached_tensors[name]
 
     def sync(self):
-        # Check if we need to sync
-        do_sync = False
-        if self.strategy == SaveCtxStrategy.REPLICATE_ALL:
-            do_sync = True
-
-        # Sync if necessary.
-        if do_sync is True:
-            pipeline_sync(self._modules, self.fmps)
-        do_sync = False
-
-
-class TrainableModules(nn.ModuleDict):
-    def __init__(
-        self,
-        fmps: _ParallelStateAPI,
-        *args,
-        strategy: TrainableModulesStrategy = TrainableModulesStrategy.REPLICATE_ALL,
-        **kwargs,
-    ) -> None:
-        super().__init__(*args, **kwargs)
-        self.strategy = strategy
-        self.fmps = fmps
-
-    def sync(self):
-        # Check if we need to sync
-        do_sync = False
-        if self.strategy == TrainableModulesStrategy.REPLICATE_ALL:
-            do_sync = True
-
-        # Sync if necessary.
-        if do_sync is True:
-            pipeline_sync(self._modules, self.fmps)
+        _pipeline_sync(self.cached_tensors, self.fmps)
diff --git a/flex_model/distributed/strategies.py b/flex_model/distributed/strategies.py
index 7606eca..c232789 100644
--- a/flex_model/distributed/strategies.py
+++ b/flex_model/distributed/strategies.py
@@ -1,4 +1,3 @@
-from enum import Enum
 from typing import Any, Callable, Dict, List, Optional
 
 from torch import Tensor
@@ -280,27 +279,3 @@ def is_valid(fn) -> bool:
             return cls(user_func)
         else:
             raise Exception("Provided editing function is not valid")
-
-
-class SaveCtxStrategy(Enum):
-    """
-    Defines a strategy which handles allocating `save_ctx` data. Default
-    stratgy is to simply replicate all `save_ctx` across all workers.
-    """
-
-    REPLICATE_ALL = 1
-    REPLICATE_DP = 2
-    REPLICATE_PP = 3
-    REPLICATE_TP = 4
-
-
-class TrainableModulesStrategy(Enum):
-    """
-    Defines a strategy which handles allocating `trainable_modules`.
-    Default stratgy is to simply replicate across all workers.
-    """
-
-    REPLICATE_ALL = 1
-    REPLICATE_DP = 2
-    REPLICATE_PP = 3
-    REPLICATE_TP = 4

From b0000004c08adac1f6f3e7c3d8f0ce4083afac4a Mon Sep 17 00:00:00 2001
From: Mchoi-git
Date: Wed, 14 Feb 2024 13:53:38 -0500
Subject: [PATCH 4/4] fixed import

---
 _test/multi_gpu/distributed/test_mappings.py | 40 ++++++++++
 flex_model/distributed/__init__.py           |  4 +
 flex_model/distributed/cached_state.py       | 82 +++++++++++---------
 3 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/_test/multi_gpu/distributed/test_mappings.py b/_test/multi_gpu/distributed/test_mappings.py
index 916e8ac..35dbee8 100644
--- a/_test/multi_gpu/distributed/test_mappings.py
+++ b/_test/multi_gpu/distributed/test_mappings.py
@@ -4,6 +4,7 @@
 import torch.nn as nn
 
 import flex_model.distributed as fm_dist
+from flex_model.distributed import sync_pipeline_parallel
 from _test.multi_gpu.registry import SlurmJobResourceSpec, make_test_registry
 import _test.multi_gpu.testing_utils as utils
 
@@ -219,3 +220,42 @@ def test_gather_pipeline_parallel_dtypes():
             tensor,
         )
     utils.print_success("test_gather_pipeline_parallel_dtypes")
+
+
+@register_mappings_test
+def test_pipeline_sync():
+    utils.init_process_group()
+
+    model = nn.Linear(2, 4)
+    fmps = fm_dist.initialize_distributed_state(model, 1, _NUM_GPUS, 1)
+
+    rank = fmps.get_pipeline_parallel_rank()
+    world_size = fmps.get_pipeline_parallel_world_size()
+
+    tensor_dict = {}
+    tensors_per_rank = 4
+    dtypes = [torch.float32, torch.float16, torch.bfloat16]
+    for dtype in dtypes:
+        for i in range(tensors_per_rank):
+            tensor_idx = rank * tensors_per_rank + i
+            name = f"tensor_{tensor_idx}_{dtype}"
+            tensor = torch.ones((1,), dtype=dtype) * tensor_idx
+            tensor_dict[name] = tensor
+
+    result = sync_pipeline_parallel(fmps, tensor_dict)
+
+    if torch.distributed.get_rank() == 1:
+        breakpoint()
+    torch.distributed.barrier()
+
+    assert len(result) == tensors_per_rank * world_size * len(dtypes)
+    for dtype in dtypes:
+        for i in range(tensors_per_rank):
+            tensor_idx = rank * tensors_per_rank + i
+            name = f"tensor_{tensor_idx}_{dtype}"
+            tensor = torch.ones((1,), dtype=dtype) * tensor_idx
+            assert torch.equal(
+                result[name],
+                tensor,
+            )
+    utils.print_success("test_gather_pipeline_parallel_dtypes")
diff --git a/flex_model/distributed/__init__.py b/flex_model/distributed/__init__.py
index d11ee96..6b309e9 100644
--- a/flex_model/distributed/__init__.py
+++ b/flex_model/distributed/__init__.py
@@ -24,3 +24,7 @@
     BaseFunctionStrategy,
     NonValidatedFunctionStrategy,
 )
+from .cached_state import (
+    sync_pipeline_parallel,
+    SaveContext,
+)
diff --git a/flex_model/distributed/cached_state.py b/flex_model/distributed/cached_state.py
index d88f53c..d3211a5 100644
--- a/flex_model/distributed/cached_state.py
+++ b/flex_model/distributed/cached_state.py
@@ -3,16 +3,19 @@
 
 import torch
 from torch import Tensor
+import torch.distributed as dist
 
 from flex_model.distributed.distributed_state import _ParallelStateAPI
-from flex_model.distributed.mappings import _log_shape
-from flex_model.distributed.stratgies import (
-    SaveCtxStrategy,
-)
 
 logger = logging.getLogger(__name__)
 
 
+def _log_shape(rank, fn_name, in_shape, out_shape):
+    logger.debug(
+        f"Local rank{rank} - {fn_name} | Input: {in_shape} -> {out_shape}"
+    )
+
+
 def _group_by_dtype(
     tensors: Dict[str, Tensor]
 ) -> Dict[torch.dtype, Dict[str, Tensor]]:
@@ -118,7 +121,7 @@ def _empty_groups() -> Dict[torch.dtype, List[Union[Tensor, int]]]:
 
             logger.debug(
                 f"Rank{rank}: Constructed recv - "
-                f"({tbuf.numel()}) [{src_rank}] -> [{fmps.get_rank()}]"
+                f"({tbuf.numel()}) [{src_rank}] -> [{dist.get_rank()}]"
             )
 
     # Construct send tensors and dst ranks.
@@ -128,7 +131,7 @@ def _empty_groups() -> Dict[torch.dtype, List[Union[Tensor, int]]]:
             continue
 
         # Send dst always rank0.
-        for r in fmps.get_world_size():
+        for r in range(fmps.get_pipeline_parallel_world_size()):
             send_tbuf_groups[dtype].append(tbuf)
             send_rank_groups[dtype].append(r)
 
@@ -169,32 +172,29 @@ def _set_device(_buffer_list, device):
 
     # Unshard each tbuf into individual tensors.
     output_tensor_dict: Dict[str, Tensor] = {}
-    if rank == 0:
-
-        def _reshard_tbuf(meta, tbuf):
-            for name, (start, end) in meta["name_to_index_map"].items():
-                shape = meta["name_to_shape_map"][name]
-                output_tensor_dict[name] = tbuf[start:end].reshape(shape)
-
-        # Add rank0 local tbufs.
-        for dtype, tbuf in tbuf_groups.items():
-            meta = all_metadata_groups[0][dtype]
-            if meta is not None:
-                _reshard_tbuf(meta, tbuf)
-
-        # Add gathered tbufs.
-        for recv_tbuf, recv_r in zip(all_recv_tbufs, all_recv_ranks):
-            dtype = recv_tbuf.dtype
-            meta = all_metadata_groups[recv_r][dtype]
-
-            buf_rank = meta["buffer_rank"]
-            buf_dtype = meta["buffer_dtype"]
-            assert (
-                buf_dtype == dtype
-            ), f"Dtype mismatch: {buf_dtype} and {dtype}"
-            assert buf_rank == recv_r, f"Rank mismatch: {buf_rank} and {recv_r}"
-
-            _reshard_tbuf(meta, recv_tbuf)
+
+    def _reshard_tbuf(meta, tbuf):
+        for name, (start, end) in meta["name_to_index_map"].items():
+            shape = meta["name_to_shape_map"][name]
+            output_tensor_dict[name] = tbuf[start:end].reshape(shape)
+
+    # Add rank0 local tbufs.
+    for dtype, tbuf in tbuf_groups.items():
+        meta = all_metadata_groups[0][dtype]
+        if meta is not None:
+            _reshard_tbuf(meta, tbuf)
+
+    # Add gathered tbufs.
+    for recv_tbuf, recv_r in zip(all_recv_tbufs, all_recv_ranks):
+        dtype = recv_tbuf.dtype
+        meta = all_metadata_groups[recv_r][dtype]
+
+        buf_rank = meta["buffer_rank"]
+        buf_dtype = meta["buffer_dtype"]
+        assert buf_dtype == dtype, f"Dtype mismatch: {buf_dtype} and {dtype}"
+        assert buf_rank == recv_r, f"Rank mismatch: {buf_rank} and {recv_r}"
+
+        _reshard_tbuf(meta, recv_tbuf)
 
     return output_tensor_dict
 
@@ -269,7 +269,7 @@ def _gen_debug_msg(t_list):
     torch.cuda.synchronize()
 
 
-def gather_pipeline_parallel_tensor_dicts(
+def sync_pipeline_parallel(
     fmps: _ParallelStateAPI,
     tensor_dict: Dict[str, Tensor],
 ) -> Dict[str, Tensor]:
@@ -330,17 +330,25 @@ def gather_pipeline_parallel_tensor_dicts(
     return output_tensor_dict
 
 
-def _pipeline_sync(tensor_dict: Dict[str, Tensor], fmps: _ParallelStateAPI):
-    gather_pipeline_parallel_tensor_dicts(tensor_dict, fmps)
+def sync_tensor_parallel(
+    fmps: _ParallelStateAPI,
+    tensor_dict: Dict[str, Tensor],
+) -> Dict[str, Tensor]:
+    pass
+
+
+def sync_data_parallel(
+    fmps: _ParallelStateAPI,
+    tensor_dict: Dict[str, Tensor],
+) -> Dict[str, Tensor]:
+    pass
 
 
 class SaveContext:
     def __init__(
         self,
         fmps: _ParallelStateAPI,
-        strategy: SaveCtxStrategy = SaveCtxStrategy.REPLICATE_ALL,
     ):
-        self.strategy = strategy
         self.fmps = fmps
         self.cached_tensors = {}
 
@@ -351,4 +359,4 @@ def get_tensor(self, name: str) -> Tensor:
         return self.cached_tensors[name]
 
     def sync(self):
-        _pipeline_sync(self.cached_tensors, self.fmps)
+        sync_pipeline_parallel(self.cached_tensors, self.fmps)
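
Usage sketch (illustrative only, not part of the patch series): the snippet below shows one way the API introduced in PATCH 4 could be exercised inside a distributed job, mirroring `test_pipeline_sync`. It assumes `torch.distributed` is already initialized (e.g. via `torchrun`) and that `fm_dist.initialize_distributed_state` accepts the same argument pattern used in the test; the model and parallel sizes are placeholders.

# Illustrative sketch only. Assumes a running torch.distributed job and the
# flex_model.distributed API as it exists after PATCH 4.
import torch
import torch.nn as nn

import flex_model.distributed as fm_dist
from flex_model.distributed import SaveContext, sync_pipeline_parallel

# Placeholder model and parallel sizes, following the call pattern in
# test_pipeline_sync above.
model = nn.Linear(2, 4)
fmps = fm_dist.initialize_distributed_state(model, 1, 2, 1)

# Each pipeline rank caches a tensor under a rank-unique name.
save_ctx = SaveContext(fmps)
pp_rank = fmps.get_pipeline_parallel_rank()
save_ctx.save(f"activation_rank{pp_rank}", torch.ones(1) * pp_rank)

# Gather every rank's cached tensors; after PATCH 4 the gather uses
# all_gather_object plus batched P2P ops, so every pipeline rank receives
# the merged dictionary, matching the assertions in the test.
merged = sync_pipeline_parallel(fmps, save_ctx.cached_tensors)
assert len(merged) == fmps.get_pipeline_parallel_world_size()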