1414"""Experimental resharding API for elastic device sets."""
1515
1616import base64
17+ import collections
18+ import itertools
1719import json
1820from typing import Any , Dict , Sequence
1921
@@ -92,9 +94,7 @@ def _get_resharding_plan(
9294 donate : bool ,
9395) -> ReshardingPlanWrapper :
9496 """Returns a resharding plan for the given sharding task."""
95- return ReshardingPlanWrapper (
96- avals , old_shardings , new_shardings , donate
97- )
97+ return ReshardingPlanWrapper (avals , old_shardings , new_shardings , donate )
9898
9999
100100_get_resharding_plan_cached = lru_cache .lru_cache ()(_get_resharding_plan )
@@ -116,12 +116,11 @@ def reshard(
116116 (must be a tree prefix of `x`), representing the device(s) and sharding to
117117 which `x` should be sharded. The result will be committed to the
118118 device(s) of the sharding.
119- donate: If `True`, donate all input arrays, which may reduce the
120- amount memory needed for resharding. Buffers donated to resharding should
121- not be reused.
122- may_alias: If `True`, may alias the input array with the output array.
123- May reduce the amount of memory needed for resharding. Not used at the
124- moment.
119+ donate: If `True`, donate all input arrays, which may reduce the amount of
120+ memory needed for resharding. Buffers donated to resharding should not be
121+ reused.
122+ may_alias: If `True`, may alias the input array with the output array. May
123+ reduce the amount of memory needed for resharding. Not used at the moment.
125124 cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
126125 recreating plans for the same resharding operation. May improve
127126 performance for use cases where the same resharding operation is done many
@@ -137,43 +136,59 @@ def reshard(
137136 "reshard sharding" , tree_def , sharding
138137 )
139138
140- jax_arrays = []
141- jax_array_dst_shardings = []
142- non_jax_arrays = []
143- non_jax_array_dst_shardings = []
144- for arr , dst_sharding in zip (flat_x , flat_sharding ):
139+ # We must split the arrays into two groups:
140+ # 1. jax.Array
141+ # 2. non jax.Array
142+ # For jax.Array, we will use the ifrt client to get the resharding plan and
143+ # execute it.
144+ # These arrays must be further split into groups based on the device set of
145+ # the sharding, since plugin programs only support execution on the same
146+ # device set.
147+ # For non jax.Array, we will use jax.device_put to put the array to the
148+ # destination devices.
149+ jax_arrays = collections .defaultdict (
150+ lambda : {"arrays" : [], "indices" : [], "dst_shardings" : []}
151+ )
152+ non_jax_arrays = {"arrays" : [], "indices" : [], "dst_shardings" : []}
153+ for index , (arr , dst_sharding ) in enumerate (zip (flat_x , flat_sharding )):
145154 if not isinstance (dst_sharding , jax .sharding .Sharding ):
146155 raise ValueError ("`sharding` must contain only `jax.sharding.Sharding`" )
147156 if isinstance (arr , jax .Array ):
148- jax_arrays .append (arr )
149- jax_array_dst_shardings .append (dst_sharding )
157+ key = frozenset (arr .sharding .device_set )
158+ jax_arrays [key ]["arrays" ].append (arr )
159+ jax_arrays [key ]["indices" ].append (index )
160+ jax_arrays [key ]["dst_shardings" ].append (dst_sharding )
150161 else :
151- non_jax_arrays .append (arr )
152- non_jax_array_dst_shardings .append (dst_sharding )
153-
154- if non_jax_arrays :
155- non_jax_arrays = jax .device_put (non_jax_arrays , non_jax_array_dst_shardings )
162+ non_jax_arrays ["arrays" ].append (arr )
163+ non_jax_arrays ["indices" ].append (index )
164+ non_jax_arrays ["dst_shardings" ].append (dst_sharding )
165+
166+ if non_jax_arrays ["arrays" ]:
167+ non_jax_arrays ["arrays" ] = jax .device_put (
168+ non_jax_arrays ["arrays" ],
169+ non_jax_arrays ["dst_shardings" ],
170+ donate = donate ,
171+ may_alias = may_alias ,
172+ )
156173
157- if jax_arrays :
174+ for array_info in jax_arrays . values () :
158175 get_resharding_plan_func = (
159176 _get_resharding_plan_cached
160177 if cache_resharding_plans
161178 else _get_resharding_plan
162179 )
163- jax_arrays = get_resharding_plan_func (
164- tuple (arr .aval for arr in jax_arrays ),
165- tuple (arr .sharding for arr in jax_arrays ),
166- tuple (jax_array_dst_shardings ),
180+ array_info [ "arrays" ] = get_resharding_plan_func (
181+ tuple (arr .aval for arr in array_info [ "arrays" ] ),
182+ tuple (arr .sharding for arr in array_info [ "arrays" ] ),
183+ tuple (array_info [ "dst_shardings" ] ),
167184 donate ,
168- ).execute (tuple (jax_arrays ))
185+ ).execute (tuple (array_info [ "arrays" ] ))
169186
170- result = []
171- jax_iter = iter (jax_arrays )
172- non_jax_iter = iter (non_jax_arrays )
187+ result = [None ] * len (flat_x )
188+ for arr , idx in zip (non_jax_arrays ["arrays" ], non_jax_arrays ["indices" ]):
189+ result [idx ] = arr
190+ for array_info in jax_arrays .values ():
191+ for arr , idx in zip (array_info ["arrays" ], array_info ["indices" ]):
192+ result [idx ] = arr
173193
174- for arr in flat_x :
175- if isinstance (arr , jax .Array ):
176- result .append (next (jax_iter ))
177- else :
178- result .append (next (non_jax_iter ))
179194 return jax .tree .unflatten (tree_def , result )
0 commit comments