Skip to content

Commit b7f44b7

Browse files
lc5211The tunix Authors
authored and committed
[Tunix] Hide internal helper functions for 1p only in reshard.py.
PiperOrigin-RevId: 888783253
1 parent 8ea9ef7 commit b7f44b7

File tree

1 file changed

+0
-251
lines changed

1 file changed

+0
-251
lines changed

tunix/rl/reshard.py

Lines changed: 0 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -58,257 +58,6 @@ def wait():
5858
threading.Thread(target=wait).start()
5959

6060

61-
def _identity(x):
62-
return x
63-
64-
# Suffixes appended to an existing mesh-axis name when that axis is split into
# a (split, replica) pair of synthetic axes for the intermediate sharding.
# `_experimental_pre_reshard` later finds the replica axis by matching the
# replica suffix, so these two must stay distinct.
INTERMEDIATE_SPLIT_SUFFIX = '_intermediate_split'
INTERMEDIATE_REPLICA_SUFFIX = '_intermediate_replica'
def _maybe_find_intermediate_sharding(source_sharding, target_sharding):
  """Maybe finds an intermediate sharding to reshard to before target sharding.

  This function tries to find an intermediate sharding that can be used to
  reshard the source sharding to the target sharding. This is useful when
  resharding from a source sharding to a target sharding that requires an
  all-gather, which can be expensive.

  For example, consider resharding an array from src_sharding (e.g., [fsdp: 8,
  tp: 1]) to target_sharding (e.g., [fsdp: 1, tp: 4]). In this case, the source
  has a larger sharding factor (8) than the target largest sharding factor (4)
  on the tp dimension.
  To avoid an expensive all-gather, we can introduce an intermediate sharding
  (e.g., [fsdp_split: 4, fsdp_replica: 2, tp: 1]). This intermediate sharding
  allows us to reshard the source array by still sharding along the fsdp
  dimension and replicating it on the remaining devices. Then we can just
  reshard any replica of the source to the target as normal.

  Args:
    source_sharding: The source sharding.
    target_sharding: The target sharding.

  Returns:
    An intermediate sharding, or None if no intermediate sharding can be found.
  """
  # Only NamedSharding carries the mesh/spec structure this optimization needs.
  if not isinstance(
      source_sharding, jax.sharding.NamedSharding
  ) or not isinstance(target_sharding, jax.sharding.NamedSharding):
    logging.vlog(
        2,
        'None-NamedSharding does not need intermediate sharding.'
        f' {source_sharding=}, {target_sharding=}',
    )
    return None
  src_mesh = source_sharding.mesh
  dst_mesh = target_sharding.mesh

  def _get_sharding_dims(sharding, mesh):
    # Maps (spec_dim_index, mesh_axis_index) -> number of shards on that pair.
    # A key component of None marks "unsharded"/"no spec dim" respectively.
    sharding_dims = {}
    for i, axis_name in enumerate(sharding.spec):
      if axis_name is None:
        # Unsharded array dimension: a single shard.
        sharding_dims[(i, None)] = 1
      elif isinstance(axis_name, str):
        sharding_dims[(i, mesh.axis_names.index(axis_name))] = mesh.shape[
            axis_name
        ]
      elif isinstance(axis_name, (tuple, list)):
        for _, axis in enumerate(axis_name):
          if axis is None:
            sharding_dims[(i, None)] = 1
          # Only handles two-level logical axis rules for now.
          elif isinstance(axis, str):
            sharding_dims[(i, mesh.axis_names.index(axis))] = (
                mesh.shape[axis]
            )
          else:
            raise ValueError(f'Unsupported axis name: {axis_name}')
      else:
        raise ValueError(
            f'Unsupported axis name: {axis_name} with type {type(axis_name)}'
        )

    # Largest per-dimension shard count seen; 1 for a fully unsharded spec.
    largest_shards = max(sharding_dims.values()) if len(sharding_dims) else 1
    # Pad with 1-shard entries for mesh axes that the spec did not mention, so
    # src/dst dicts can be zipped positionally below.
    if len(sharding_dims) < len(mesh.shape):
      for mi, mesh_axis in enumerate(mesh.axis_names):
        # NOTE(review): `keys[1]` is a mesh-axis *index* (or None) while
        # `mesh_axis` is an axis *name*, so this comparison looks always
        # False — suspect it was meant to be `mi == keys[1]`. TODO confirm;
        # the extra 1-valued entries it adds do not change the products below.
        matched = any(mesh_axis == keys[1] for keys in sharding_dims)
        if not matched:
          sharding_dims[(None, mi)] = 1
    return sharding_dims, largest_shards

  src_sharding_dims, src_largest_shards = _get_sharding_dims(
      source_sharding, src_mesh
  )
  dst_sharding_dims, dst_largest_shards = _get_sharding_dims(
      target_sharding, dst_mesh
  )
  # Not able to handle resharding with undividable shardings.
  if src_largest_shards % dst_largest_shards != 0:
    logging.debug(
        'Resharding with undividable shardings is not optimized with'
        ' experimental pre-reshard.'
        ' source_sharding=%s, target_sharding=%s',
        source_sharding,
        target_sharding,
    )
    return None

  # Only worth it when the source is strictly more sharded overall and the
  # target's total shard count divides the source's.
  total_source_sharding_dims = math.prod(list(src_sharding_dims.values()))
  total_dst_sharding_dims = math.prod(list(dst_sharding_dims.values()))
  if (
      total_source_sharding_dims <= total_dst_sharding_dims
      or total_source_sharding_dims % total_dst_sharding_dims != 0
  ):
    return None

  new_split_dim_shards = None
  new_split_axis = None
  # How many replicas the intermediate sharding will hold.
  replicas = src_largest_shards // dst_largest_shards

  # Find gcd(src_dim_shards, dst_dim_shards),
  # If all of them are 1s, an all-gather is needed as the single replica of
  # the source cannot be presented by any sharded form on the target devices.
  # NOTE(review): zip() pairs the two dicts positionally, which assumes both
  # iterate their (dim, axis) keys in a compatible order — TODO confirm.
  gcd_shards = []
  for (sharding_mesh_axis_idx, src_dim_shards), (_, dst_dim_shards) in zip(
      src_sharding_dims.items(), dst_sharding_dims.items()
  ):
    gcd_dim_shards = math.gcd(src_dim_shards, dst_dim_shards)
    if gcd_dim_shards == 1:
      if (
          src_dim_shards > dst_dim_shards
          and src_dim_shards == src_largest_shards
      ):
        # Split the source's most-sharded axis into (kept shards, replicas).
        new_split_axis = sharding_mesh_axis_idx
        new_split_dim_shards = (src_dim_shards // replicas, replicas)
    gcd_shards.append(gcd_dim_shards)

  # Bail out unless every per-dim gcd is 1 (otherwise a plain reshard works)
  # and we actually found an axis to split.
  if math.prod(gcd_shards) != 1 or new_split_axis is None:
    return None

  # Generate the intermediate sharding.
  # new_split_axis is (spec_dim_index, mesh_axis_index); the mesh axis at
  # mesh_axis_index is replaced by a (split, replica) pair of axes.
  new_split_mesh_axis_name = (
      src_mesh.axis_names[new_split_axis[1]] + INTERMEDIATE_SPLIT_SUFFIX
  )
  new_split_mesh_replica_axis_name = (
      src_mesh.axis_names[new_split_axis[1]] + INTERMEDIATE_REPLICA_SUFFIX
  )
  intermediate_mesh = jax.sharding.Mesh(
      src_mesh.devices.reshape(
          tuple(
              list(src_mesh.devices.shape[: new_split_axis[1]])
              + [new_split_dim_shards[0], new_split_dim_shards[1]]
              + list(src_mesh.devices.shape[new_split_axis[1] + 1 :])
          )
      ),
      axis_names=tuple(
          list(src_mesh.axis_names[: new_split_axis[1]])
          + [new_split_mesh_axis_name, new_split_mesh_replica_axis_name]
          + list(src_mesh.axis_names[new_split_axis[1] + 1 :])
      ),
  )

  # The array stays sharded only over the split axis; the replica axis is
  # intentionally absent from the spec, so data is replicated across it.
  intermediate_spec = tuple(
      list(source_sharding.spec[: new_split_axis[0]])
      + [new_split_mesh_axis_name]
      + list(source_sharding.spec[new_split_axis[0] + 1 :])
  )
  intermediate_sharding = jax.sharding.NamedSharding(
      intermediate_mesh,
      jax.sharding.PartitionSpec(*intermediate_spec),
      memory_kind=source_sharding.memory_kind,
  )
  return intermediate_sharding
def _experimental_pre_reshard(splitfn, src_pytree, target_shardings):
  """Simple heuristic to determine if resharding with replicated all-gather is needed.

  A replicated all-gather often results in heavy HBM occupation which we need
  to avoid. This function currently only handles the case like resharding from
  [fsdp: 8, tp: 1] to [fsdp: 1, tp: 4].
  We will improve the coverage on more complex cases along the development.

  Args:
    splitfn: The split function. Called as `splitfn(arrays, axis_name)` and
      expected to return the split arrays as its first result.
    src_pytree: The source pytree of jax Arrays.
    target_shardings: The target shardings (pytree matching `src_pytree`).

  Returns:
    Pre-resharded src_pytree (or `src_pytree` unchanged when no leaf needs an
    intermediate sharding).
  """
  src_shardings = jax.tree_util.tree_map(
      lambda x: x.sharding,
      src_pytree,
  )
  # Per-leaf intermediate sharding, or None where a direct reshard is fine.
  intermediate_shardings = jax.tree_util.tree_map(
      _maybe_find_intermediate_sharding,
      src_shardings,
      target_shardings,
  )

  src_leaves_with_path, src_treedef = jax.tree_util.tree_flatten_with_path(
      src_pytree
  )
  intermediate_sharding_leaves_with_path, _ = (
      jax.tree_util.tree_flatten_with_path(intermediate_shardings)
  )
  # Re-key by path so each source leaf can look up its intermediate sharding.
  intermediate_sharding_leaves_with_path = {
      path: intermediate_sharding
      for path, intermediate_sharding in intermediate_sharding_leaves_with_path
  }

  # Group the leaves needing pre-resharding by their intermediate mesh, so each
  # group can go through one jit call / one splitfn call.
  to_split_src_pytree_leaves = {}
  to_split_src_pytree_leaves_indexes = {}
  to_split_intermediate_sharding_leaves = {}

  intermediate_mesh = None
  to_update_src_pytree_leaves = []

  for i, (path, src) in enumerate(src_leaves_with_path):
    to_update_src_pytree_leaves.append(src)
    if intermediate_sharding := intermediate_sharding_leaves_with_path.get(
        path, None
    ):
      # The to_split_axis should always be the same along all the intermediate
      # shardings.
      intermediate_mesh = intermediate_sharding.mesh
      to_split_src_pytree_leaves.setdefault(intermediate_mesh, []).append(src)
      to_split_src_pytree_leaves_indexes.setdefault(intermediate_mesh, []).append(i)
      to_split_intermediate_sharding_leaves.setdefault(intermediate_mesh, []).append(intermediate_sharding)

  if intermediate_mesh is None:
    # No pre-resharding is needed.
    return src_pytree

  for _intermediate_mesh in to_split_src_pytree_leaves.keys():
    # The replica axis is identified by the suffix attached when the
    # intermediate mesh was built in _maybe_find_intermediate_sharding.
    to_split_axis = None
    for axis_name in _intermediate_mesh.axis_names:
      if axis_name.endswith(INTERMEDIATE_REPLICA_SUFFIX):
        to_split_axis = axis_name
        break
    assert (
        to_split_axis is not None
    ), f'No replica axis found in the intermediate mesh {_intermediate_mesh}.'

    # Reshard this group onto the intermediate shardings via a no-op jit.
    temp_source = jax.jit(
        _identity,
        out_shardings=to_split_intermediate_sharding_leaves[_intermediate_mesh],
    )(to_split_src_pytree_leaves[_intermediate_mesh])

    # Update the to_split_src_pytree_leaves with the new splitted array.
    updated_to_split_src_pytree_leaves, *_ = splitfn(temp_source, to_split_axis)

    # Write each split result back to its original position in the flat list.
    for i in range(len(to_split_src_pytree_leaves_indexes[_intermediate_mesh])):
      to_update_src_pytree_leaves[
          to_split_src_pytree_leaves_indexes[_intermediate_mesh][i]
      ] = updated_to_split_src_pytree_leaves[i]

  updated_src_pytree = jax.tree_util.tree_unflatten(
      src_treedef, to_update_src_pytree_leaves
  )
  return updated_src_pytree
31261
#
31362

31463

0 commit comments

Comments
 (0)