@@ -58,257 +58,6 @@ def wait():
5858 threading .Thread (target = wait ).start ()
5959
6060
61- def _identity (x ):
62- return x
63-
64-
# Suffixes appended to a source mesh axis name when building the intermediate
# mesh for pre-resharding: the chosen axis is split into a
# '<axis>_intermediate_split' axis (stays sharded) and a
# '<axis>_intermediate_replica' axis (replicated). The replica axis is later
# located by matching on INTERMEDIATE_REPLICA_SUFFIX with `endswith`, so these
# string values must stay in sync with that lookup.
INTERMEDIATE_SPLIT_SUFFIX = '_intermediate_split'
INTERMEDIATE_REPLICA_SUFFIX = '_intermediate_replica'
67-
68-
69- def _maybe_find_intermediate_sharding (source_sharding , target_sharding ):
70- """Maybe finds an intermediate sharding to reshard to before target sharding.
71-
72- This function tries to find an intermediate sharding that can be used to
73- reshard the source sharding to the target sharding. This is useful when
74- resharding from a source sharding to a target sharding that requires an
75- all-gather, which can be expensive.
76-
77- For example, consider resharding an array from src_sharding (e.g., [fsdp: 8,
78- tp: 1]) to target_sharding (e.g., [fsdp: 1, tp: 4]). In this case, the source
79- has a larger sharding factor (8) than the target largest sharding factor (4)
80- on the tp dimension.
81- To avoid an expensive all-gather, we can introduce an intermediate sharding
82- (e.g., [fsdp_split: 4, fsdp_replica: 2, tp: 1]). This intermediate sharding
83- allows us to reshard the source array by still sharding along the fsdp
84- dimension and replicating it on the remaining devices. Then we can just
85- reshard any replica of the source to the target as normal.
86-
87- Args:
88- source_sharding: The source sharding.
89- target_sharding: The target sharding.
90-
91- Returns:
92- An intermediate sharding, or None if no intermediate sharding can be found.
93- """
94- if not isinstance (
95- source_sharding , jax .sharding .NamedSharding
96- ) or not isinstance (target_sharding , jax .sharding .NamedSharding ):
97- logging .vlog (
98- 2 ,
99- 'None-NamedSharding does not need intermediate sharding.'
100- f' { source_sharding = } , { target_sharding = } ' ,
101- )
102- return None
103- src_mesh = source_sharding .mesh
104- dst_mesh = target_sharding .mesh
105-
106- def _get_sharding_dims (sharding , mesh ):
107- sharding_dims = {}
108- for i , axis_name in enumerate (sharding .spec ):
109- if axis_name is None :
110- sharding_dims [(i , None )] = 1
111- elif isinstance (axis_name , str ):
112- sharding_dims [(i , mesh .axis_names .index (axis_name ))] = mesh .shape [
113- axis_name
114- ]
115- elif isinstance (axis_name , (tuple , list )):
116- for _ , axis in enumerate (axis_name ):
117- if axis is None :
118- sharding_dims [(i , None )] = 1
119- # Only handles two-level logical axis rules for now.
120- elif isinstance (axis , str ):
121- sharding_dims [(i , mesh .axis_names .index (axis ))] = (
122- mesh .shape [axis ]
123- )
124- else :
125- raise ValueError (f'Unsupported axis name: { axis_name } ' )
126- else :
127- raise ValueError (
128- f'Unsupported axis name: { axis_name } with type { type (axis_name )} '
129- )
130-
131- largest_shards = max (sharding_dims .values ()) if len (sharding_dims ) else 1
132- if len (sharding_dims ) < len (mesh .shape ):
133- for mi , mesh_axis in enumerate (mesh .axis_names ):
134- matched = any (mesh_axis == keys [1 ] for keys in sharding_dims )
135- if not matched :
136- sharding_dims [(None , mi )] = 1
137- return sharding_dims , largest_shards
138-
139- src_sharding_dims , src_largest_shards = _get_sharding_dims (
140- source_sharding , src_mesh
141- )
142- dst_sharding_dims , dst_largest_shards = _get_sharding_dims (
143- target_sharding , dst_mesh
144- )
145- # Not able to handle resharding with undividable shardings.
146- if src_largest_shards % dst_largest_shards != 0 :
147- logging .debug (
148- 'Resharding with undividable shardings is not optimized with'
149- ' experimental pre-reshard.'
150- ' source_sharding=%s, target_sharding=%s' ,
151- source_sharding ,
152- target_sharding ,
153- )
154- return None
155-
156- total_source_sharding_dims = math .prod (list (src_sharding_dims .values ()))
157- total_dst_sharding_dims = math .prod (list (dst_sharding_dims .values ()))
158- if (
159- total_source_sharding_dims <= total_dst_sharding_dims
160- or total_source_sharding_dims % total_dst_sharding_dims != 0
161- ):
162- return None
163-
164- new_split_dim_shards = None
165- new_split_axis = None
166- replicas = src_largest_shards // dst_largest_shards
167-
168- # Find gcd(src_dim_shards, dst_dim_shards),
169- # If all of them are 1s, an all-gather is needed as the single replica of
170- # the source cannot be presented by any sharded form on the target devices.
171- gcd_shards = []
172- for (sharding_mesh_axis_idx , src_dim_shards ), (_ , dst_dim_shards ) in zip (
173- src_sharding_dims .items (), dst_sharding_dims .items ()
174- ):
175- gcd_dim_shards = math .gcd (src_dim_shards , dst_dim_shards )
176- if gcd_dim_shards == 1 :
177- if (
178- src_dim_shards > dst_dim_shards
179- and src_dim_shards == src_largest_shards
180- ):
181- new_split_axis = sharding_mesh_axis_idx
182- new_split_dim_shards = (src_dim_shards // replicas , replicas )
183- gcd_shards .append (gcd_dim_shards )
184-
185- if math .prod (gcd_shards ) != 1 or new_split_axis is None :
186- return None
187-
188- # Generate the intermediate sharding.
189- new_split_mesh_axis_name = (
190- src_mesh .axis_names [new_split_axis [1 ]] + INTERMEDIATE_SPLIT_SUFFIX
191- )
192- new_split_mesh_replica_axis_name = (
193- src_mesh .axis_names [new_split_axis [1 ]] + INTERMEDIATE_REPLICA_SUFFIX
194- )
195- intermediate_mesh = jax .sharding .Mesh (
196- src_mesh .devices .reshape (
197- tuple (
198- list (src_mesh .devices .shape [: new_split_axis [1 ]])
199- + [new_split_dim_shards [0 ], new_split_dim_shards [1 ]]
200- + list (src_mesh .devices .shape [new_split_axis [1 ] + 1 :])
201- )
202- ),
203- axis_names = tuple (
204- list (src_mesh .axis_names [: new_split_axis [1 ]])
205- + [new_split_mesh_axis_name , new_split_mesh_replica_axis_name ]
206- + list (src_mesh .axis_names [new_split_axis [1 ] + 1 :])
207- ),
208- )
209-
210- intermediate_spec = tuple (
211- list (source_sharding .spec [: new_split_axis [0 ]])
212- + [new_split_mesh_axis_name ]
213- + list (source_sharding .spec [new_split_axis [0 ] + 1 :])
214- )
215- intermediate_sharding = jax .sharding .NamedSharding (
216- intermediate_mesh ,
217- jax .sharding .PartitionSpec (* intermediate_spec ),
218- memory_kind = source_sharding .memory_kind ,
219- )
220- return intermediate_sharding
221-
222-
def _experimental_pre_reshard(splitfn, src_pytree, target_shardings):
  """Simple heuristic to determine if resharding with replicated all-gather is needed.

  A replicated all-gather often results in heavy HBM occupation which we need
  to avoid. This function currently only handles the case like resharding from
  [fsdp: 8, tp: 1] to [fsdp: 1, tp: 4].
  We will improve the coverage on more complex cases along the development.

  Args:
    splitfn: The split function. Called as ``splitfn(arrays, axis_name)`` and
      expected to return the split arrays as its first result — TODO confirm
      the full contract against the caller.
    src_pytree: The source jax Array pytree.
    target_shardings: The target shardings (a pytree matching src_pytree).

  Returns:
    Pre-resharded src_pytree.
  """
  # Shardings of the current source leaves, same tree structure as src_pytree.
  src_shardings = jax.tree_util.tree_map(
      lambda x: x.sharding,
      src_pytree,
  )
  # For each leaf, either an intermediate NamedSharding or None when no
  # pre-resharding is beneficial (see _maybe_find_intermediate_sharding).
  intermediate_shardings = jax.tree_util.tree_map(
      _maybe_find_intermediate_sharding,
      src_shardings,
      target_shardings,
  )

  src_leaves_with_path, src_treedef = jax.tree_util.tree_flatten_with_path(
      src_pytree
  )
  intermediate_sharding_leaves_with_path, _ = (
      jax.tree_util.tree_flatten_with_path(intermediate_shardings)
  )
  # Re-key the intermediate shardings by tree path so source leaves can be
  # joined with their intermediate sharding by path lookup below.
  intermediate_sharding_leaves_with_path = {
      path: intermediate_sharding
      for path, intermediate_sharding in intermediate_sharding_leaves_with_path
  }

  # Leaves grouped by their intermediate mesh; one jit/split pass per mesh.
  to_split_src_pytree_leaves = {}
  # Flat-leaf indexes (into src_leaves_with_path order) for write-back.
  to_split_src_pytree_leaves_indexes = {}
  to_split_intermediate_sharding_leaves = {}

  intermediate_mesh = None
  to_update_src_pytree_leaves = []

  for i, (path, src) in enumerate(src_leaves_with_path):
    to_update_src_pytree_leaves.append(src)
    # NOTE(review): the walrus relies on a NamedSharding being truthy (None
    # leaves are skipped); `is not None` would be more explicit — confirm
    # NamedSharding never overrides __bool__.
    if intermediate_sharding := intermediate_sharding_leaves_with_path.get(
        path, None
    ):
      # The to_split_axis should always be the same along all the intermediate
      # shardings.
      intermediate_mesh = intermediate_sharding.mesh
      to_split_src_pytree_leaves.setdefault(intermediate_mesh, []).append(src)
      to_split_src_pytree_leaves_indexes.setdefault(intermediate_mesh, []).append(i)
      to_split_intermediate_sharding_leaves.setdefault(intermediate_mesh, []).append(intermediate_sharding)

  if intermediate_mesh is None:
    # No pre-resharding is needed.
    return src_pytree

  for _intermediate_mesh in to_split_src_pytree_leaves.keys():
    # Locate the replicated axis that _maybe_find_intermediate_sharding
    # created; it is identified purely by its name suffix.
    to_split_axis = None
    for axis_name in _intermediate_mesh.axis_names:
      if axis_name.endswith(INTERMEDIATE_REPLICA_SUFFIX):
        to_split_axis = axis_name
        break
    assert (
        to_split_axis is not None
    ), f'No replica axis found in the intermediate mesh {_intermediate_mesh}.'

    # Reshard this mesh's leaves onto their intermediate shardings via an
    # identity jit whose out_shardings drive the device movement.
    temp_source = jax.jit(
        _identity,
        out_shardings=to_split_intermediate_sharding_leaves[_intermediate_mesh],
    )(to_split_src_pytree_leaves[_intermediate_mesh])

    # Update the to_split_src_pytree_leaves with the new splitted array.
    updated_to_split_src_pytree_leaves, *_ = splitfn(temp_source, to_split_axis)

    # Write each split result back into its original flat-leaf position.
    for i in range(len(to_split_src_pytree_leaves_indexes[_intermediate_mesh])):
      to_update_src_pytree_leaves[
          to_split_src_pytree_leaves_indexes[_intermediate_mesh][i]
      ] = updated_to_split_src_pytree_leaves[i]

  updated_src_pytree = jax.tree_util.tree_unflatten(
      src_treedef, to_update_src_pytree_leaves
  )
  return updated_src_pytree
310-
311-
31261#
31362
31463
0 commit comments