diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index 43707f3..09bfe08 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -14,6 +14,7 @@ """Experimental resharding API for elastic device sets.""" import base64 +import collections import json from typing import Any, Dict, Sequence @@ -92,9 +93,7 @@ def _get_resharding_plan( donate: bool, ) -> ReshardingPlanWrapper: """Returns a resharding plan for the given sharding task.""" - return ReshardingPlanWrapper( - avals, old_shardings, new_shardings, donate - ) + return ReshardingPlanWrapper(avals, old_shardings, new_shardings, donate) _get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan) @@ -116,12 +115,11 @@ def reshard( (must be a tree prefix of `x`), representing the device(s) and sharding to which `x` should be sharded to. The result will be committed to the device(s) of the sharding. - donate: If `True`, donate all input arrays, which may reduce the - amount memory needed for resharding. Buffers donated to resharding should - not be reused. - may_alias: If `True`, may alias the input array with the output array. - May reduce the amount of memory needed for resharding. Not used at the - moment. + donate: If `True`, donate all input arrays, which may reduce the amount of + memory needed for resharding. Buffers donated to resharding should not be + reused. + may_alias: If `True`, may alias the input array with the output array. May + reduce the amount of memory needed for resharding. Not used at the moment. cache_resharding_plans: If `True`, uses a resharding plan cache to avoid recreating plans for the same resharding operation. 
May improve performance for use cases where the same resharding operation is done many @@ -137,43 +135,61 @@ def reshard( "reshard sharding", tree_def, sharding ) - jax_arrays = [] - jax_array_dst_shardings = [] - non_jax_arrays = [] - non_jax_array_dst_shardings = [] - for arr, dst_sharding in zip(flat_x, flat_sharding): + # We must split the arrays into two groups: + # 1. jax.Array + # 2. non jax.Array + # For jax.Array, we will use the ifrt client to get the resharding plan and + # execute it. + # These arrays must be further split into groups based on the device set of + # the sharding, since plugin programs only support execution on the same + # device set. + # For non jax.Array, we will use jax.device_put to put the array to the + # destination devices. + # + # We need to track what index each array is in the original pytree, so we can + # put them back together in the right order. + array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []} + jax_arrays = collections.defaultdict(array_info_lambda) + non_jax_arrays = array_info_lambda() + for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)): if not isinstance(dst_sharding, jax.sharding.Sharding): raise ValueError("`sharding` must contain only `jax.sharding.Sharding`") if isinstance(arr, jax.Array): - jax_arrays.append(arr) - jax_array_dst_shardings.append(dst_sharding) + device_set = frozenset(arr.sharding.device_set) + jax_arrays[device_set]["arrays"].append(arr) + jax_arrays[device_set]["indices"].append(index) + jax_arrays[device_set]["dst_shardings"].append(dst_sharding) else: - non_jax_arrays.append(arr) - non_jax_array_dst_shardings.append(dst_sharding) - - if non_jax_arrays: - non_jax_arrays = jax.device_put(non_jax_arrays, non_jax_array_dst_shardings) + non_jax_arrays["arrays"].append(arr) + non_jax_arrays["indices"].append(index) + non_jax_arrays["dst_shardings"].append(dst_sharding) + + if non_jax_arrays["arrays"]: + non_jax_arrays["arrays"] = jax.device_put( 
+ non_jax_arrays["arrays"], + non_jax_arrays["dst_shardings"], + donate=donate, + may_alias=may_alias, + ) - if jax_arrays: + for array_info in jax_arrays.values(): get_resharding_plan_func = ( _get_resharding_plan_cached if cache_resharding_plans else _get_resharding_plan ) - jax_arrays = get_resharding_plan_func( - tuple(arr.aval for arr in jax_arrays), - tuple(arr.sharding for arr in jax_arrays), - tuple(jax_array_dst_shardings), + array_info["arrays"] = get_resharding_plan_func( + tuple(arr.aval for arr in array_info["arrays"]), + tuple(arr.sharding for arr in array_info["arrays"]), + tuple(array_info["dst_shardings"]), donate, - ).execute(tuple(jax_arrays)) + ).execute(tuple(array_info["arrays"])) - result = [] - jax_iter = iter(jax_arrays) - non_jax_iter = iter(non_jax_arrays) + result = [None] * len(flat_x) + for arr, idx in zip(non_jax_arrays["arrays"], non_jax_arrays["indices"]): + result[idx] = arr + for array_info in jax_arrays.values(): + for arr, idx in zip(array_info["arrays"], array_info["indices"]): + result[idx] = arr - for arr in flat_x: - if isinstance(arr, jax.Array): - result.append(next(jax_iter)) - else: - result.append(next(non_jax_iter)) return jax.tree.unflatten(tree_def, result)