Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 51 additions & 35 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Experimental resharding API for elastic device sets."""

import base64
import collections
import json
from typing import Any, Dict, Sequence

Expand Down Expand Up @@ -92,9 +93,7 @@ def _get_resharding_plan(
donate: bool,
) -> ReshardingPlanWrapper:
"""Returns a resharding plan for the given sharding task."""
return ReshardingPlanWrapper(
avals, old_shardings, new_shardings, donate
)
return ReshardingPlanWrapper(avals, old_shardings, new_shardings, donate)


# Cached variant of `_get_resharding_plan`. `reshard` selects it when
# `cache_resharding_plans` is true, so an identical resharding request can
# reuse a previously built plan instead of constructing a new one.
# NOTE(review): `lru_cache` here is a module imported outside this view; its
# default capacity and eviction policy are not visible -- confirm before
# relying on them.
_get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan)
Expand All @@ -116,12 +115,11 @@ def reshard(
(must be a tree prefix of `x`), representing the device(s) and sharding to
which `x` should be sharded. The result will be committed to the
device(s) of the sharding.
donate: If `True`, donate all input arrays, which may reduce the
amount memory needed for resharding. Buffers donated to resharding should
not be reused.
may_alias: If `True`, may alias the input array with the output array.
May reduce the amount of memory needed for resharding. Not used at the
moment.
donate: If `True`, donate all input arrays, which may reduce the amount of
memory needed for resharding. Buffers donated to resharding should not be
reused.
may_alias: If `True`, may alias the input array with the output array. May
reduce the amount of memory needed for resharding. Not used at the moment.
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
recreating plans for the same resharding operation. May improve
performance for use cases where the same resharding operation is done many
Expand All @@ -137,43 +135,61 @@ def reshard(
"reshard sharding", tree_def, sharding
)

jax_arrays = []
jax_array_dst_shardings = []
non_jax_arrays = []
non_jax_array_dst_shardings = []
for arr, dst_sharding in zip(flat_x, flat_sharding):
# We must split the arrays into two groups:
# 1. jax.Array
# 2. non jax.Array
# For jax.Array, we will use the ifrt client to get the resharding plan and
# execute it.
# These arrays must be further split into groups based on the device set of
# their current (source) sharding, since plugin programs only support
# execution on the same device set.
# For non jax.Array, we will use jax.device_put to put the array to the
# destination devices.
#
# We need to track the index of each array in the flattened pytree, so that
# the results can be put back together in the original order.
array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []}
jax_arrays = collections.defaultdict(array_info_lambda)
non_jax_arrays = array_info_lambda()
for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)):
if not isinstance(dst_sharding, jax.sharding.Sharding):
raise ValueError("`sharding` must contain only `jax.sharding.Sharding`")
if isinstance(arr, jax.Array):
jax_arrays.append(arr)
jax_array_dst_shardings.append(dst_sharding)
device_set = frozenset(arr.sharding.device_set)
jax_arrays[device_set]["arrays"].append(arr)
jax_arrays[device_set]["indices"].append(index)
jax_arrays[device_set]["dst_shardings"].append(dst_sharding)
else:
non_jax_arrays.append(arr)
non_jax_array_dst_shardings.append(dst_sharding)

if non_jax_arrays:
non_jax_arrays = jax.device_put(non_jax_arrays, non_jax_array_dst_shardings)
non_jax_arrays["arrays"].append(arr)
non_jax_arrays["indices"].append(index)
non_jax_arrays["dst_shardings"].append(dst_sharding)

if non_jax_arrays["arrays"]:
non_jax_arrays["arrays"] = jax.device_put(
non_jax_arrays["arrays"],
non_jax_arrays["dst_shardings"],
donate=donate,
may_alias=may_alias,
)

if jax_arrays:
for array_info in jax_arrays.values():
get_resharding_plan_func = (
_get_resharding_plan_cached
if cache_resharding_plans
else _get_resharding_plan
)
jax_arrays = get_resharding_plan_func(
tuple(arr.aval for arr in jax_arrays),
tuple(arr.sharding for arr in jax_arrays),
tuple(jax_array_dst_shardings),
array_info["arrays"] = get_resharding_plan_func(
tuple(arr.aval for arr in array_info["arrays"]),
tuple(arr.sharding for arr in array_info["arrays"]),
tuple(array_info["dst_shardings"]),
donate,
).execute(tuple(jax_arrays))
).execute(tuple(array_info["arrays"]))

result = []
jax_iter = iter(jax_arrays)
non_jax_iter = iter(non_jax_arrays)
result = [None] * len(flat_x)
for arr, idx in zip(non_jax_arrays["arrays"], non_jax_arrays["indices"]):
result[idx] = arr
for array_info in jax_arrays.values():
for arr, idx in zip(array_info["arrays"], array_info["indices"]):
result[idx] = arr

for arr in flat_x:
if isinstance(arr, jax.Array):
result.append(next(jax_iter))
else:
result.append(next(non_jax_iter))
return jax.tree.unflatten(tree_def, result)