
Commit 2a4a0e8

bixia1 authored and Google-ML-Automation committed
[jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops parses their rules in einops/parsing.py.

When a SdyShardingRule object is constructed, we check the syntax of the Einsum-like notation string and its consistency with the user-provided factor_sizes, and report errors accordingly. This is done during f.def_partition.

When SdyShardingRule.build is called, during JAX to MLIR lowering, we check the consistency between the Einsum-like notation string, the factor_sizes, and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
1 parent f73fa7a commit 2a4a0e8
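
For illustration, a minimal sketch of constructing sharding rules, based on the examples in the new module's docstring; the import path is assumed from the jax/BUILD entry added in this commit and is not part of the diff itself:

# A minimal sketch based on the docstring examples in the new module; the
# import path is an assumption derived from the jax/BUILD change below.
from jax._src.custom_partitioning_sharding_rule import SdyShardingRule

# Batching matmul: the leading ... stands for any number of batching dims.
matmul_rule = SdyShardingRule("... i j, ... j k -> ... i k")

# Reshape (4, 2) -> (2, 4): i and j appear only inside compound factors,
# so their sizes must be given explicitly.
reshape_rule = SdyShardingRule("(i j) -> (j i)", i=4, j=2)

# Inconsistent rules are rejected at construction time, i.e. during
# f.def_partition.
try:
  SdyShardingRule("i j -> i j", k=2)
except ValueError as e:
  print(e)  # Factor k is not used in the rule, but size is provided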

File tree

4 files changed: +787 -0 lines changed


jax/BUILD

Lines changed: 1 addition & 0 deletions
@@ -193,6 +193,7 @@ py_library_providing_imports_info(
         "_src/custom_batching.py",
         "_src/custom_derivatives.py",
         "_src/custom_partitioning.py",
+        "_src/custom_partitioning_sharding_rule.py",
         "_src/custom_transpose.py",
         "_src/debugging.py",
         "_src/dispatch.py",
jax/_src/custom_partitioning_sharding_rule.py

Lines changed: 380 additions & 0 deletions
@@ -0,0 +1,380 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements SdyShardingRule."""

from collections import OrderedDict

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy


_CompoundFactor = tuple[str, ...]
_DimMapping = tuple[str | _CompoundFactor, ...]

# A single character replacement for ... to simplify parsing.
_ELLIPSIS: str = "…"

# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors.
_BATCHING_DIM_FACTOR_PREFIX = "?"

def _get_batching_dim_factor_name(batch_dim_order : int):
  """Constructs a factor name for a batching dimension.

  We expand the leading ... into factors representing the batching dimensions
  to support building the MLIR representation for the sharding rule. For this
  reason, we construct a factor name that won't be used by users for the
  batching dimensions.
  """
  return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"

def _parse_values(
    rule: str,
) -> tuple[_DimMapping, ...]:
  """Parses the LHS or RHS of an Einsum notation like string.

  Converts each operand or result in the Einsum notation like string to a tuple
  of _DimMapping. This very closely follows how einops parses their rules in
  einops/parsing.py.

  Args:
    rule: The Einsum notation for the operands or results of an operation.

  Returns:
    The tuple of values.

  Raises:
    ValueError: If the rule is not balanced or contains unknown characters.
  """

  # Remove unnecessary spaces in the rule to simplify the parsing process.
  words = rule.split()
  rule = " ".join(words)

  # Similar to einops rules, an empty LHS/RHS has a single scalar value.
  if not rule:
    return ((),)

  all_values = []
  # Represents all dimensions of a value. When value[0]==_ELLIPSIS, the
  # value may have 0 or more leading dimensions.
  value = []
  current_factor = None
  # A value of None indicates the current dimension is not a compound dimension,
  # while a value of [] indicates that we have just started parsing a compound
  # dimension.
  current_compound_dim: list[str] | None = None

  def add_factor(x):
    if current_compound_dim is None:
      value.append(x)
    else:
      current_compound_dim.append(x)

  for char in rule:
    if char == _ELLIPSIS:
      if (current_factor is not None or current_compound_dim is not None
          or value):
        raise ValueError(
            "Ellipsis can only be used at the beginning of a dimension")
      add_factor(_ELLIPSIS)
      continue
    if char in "(), ":
      if current_factor is not None:
        add_factor(current_factor)
        current_factor = None
      if char == "(":
        if current_compound_dim is not None:
          raise ValueError(
              "Compound factors should be one level, nested brackets are not"
              " allowed")
        current_compound_dim = []
      elif char == ")":
        if current_compound_dim is None:
          raise ValueError("Brackets are not balanced")
        if len(current_compound_dim) <= 1:
          raise ValueError("Brackets should contain at least two factors")
        value.append(tuple(current_compound_dim))
        current_compound_dim = None
      elif char == ",":
        all_values.append(tuple(value))
        value = []
    elif char == "_" or char.isdigit() or char.isalpha():
      if current_factor is None:
        if str.isdigit(char):
          raise ValueError(f"Factor names have to start with a letter, but got '{char}'")
        current_factor = char
      else:
        current_factor += char
    else:
      raise ValueError(f"Unknown character '{char}'")

  if current_compound_dim is not None:
    raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
  if current_factor is not None:
    add_factor(current_factor)
  all_values.append(tuple(value))

  return tuple(all_values)


class SdyShardingRule:
  """A representation for Shardy sharding rule.

  A SdyShardingRule includes an Einsum notation like string and an optional
  list of factor sizes. A factor is a name in the Einsum notation. If a factor
  is only used in compound factors, its size must be specified.

  SdyShardingRule examples:

  * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
  * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
  * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
  * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
  * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
  """

  def __init__(self, rule: str, **factor_sizes):
    """Constructs a SdyShardingRule object from the Einsum notation like string.

    This is done by verifying that the input Einsum notation like string with
    optional factor sizes represents a valid sharding rule and converting it
    to an internal representation.

    Args:
      rule: The Einsum notation like string for an operation.
      **factor_sizes: The optional factor sizes.

    Raises:
      ValueError: If there is any problem with the rule or factor_sizes.
    """
    if not isinstance(rule, str):
      raise TypeError(f"rule must be a str, but got {type(rule)}")
    if not all(isinstance(size, int) for size in factor_sizes.values()):
      raise TypeError(
          f"factor_sizes must be a dict of str to int, but got {factor_sizes}")

    # Replace ... with a single char to simplify parsing.
    if _ELLIPSIS in rule:
      raise ValueError(f"Unknown character '{_ELLIPSIS}'")
    if "." in rule:
      rule = rule.replace("...", _ELLIPSIS)
      if "." in rule:
        raise ValueError("Character '.' must be used inside ellipsis '...'")

    try:
      operands, results = rule.split("->")
    except ValueError as e:
      raise ValueError(f"There is no -> in rule: '{rule}'") from e

    self.operands = _parse_values(operands)
    self.results = _parse_values(results)

    # Find all factors and mark whether their size can be inferred.
    factors_inferrable = dict()
    for value in self.operands + self.results:
      for dim in value:
        if dim == _ELLIPSIS:
          continue
        if isinstance(dim, str):
          factors_inferrable[dim] = True
        else:
          for factor in dim:
            if factor not in factors_inferrable.keys():
              factors_inferrable[factor] = False

    # Check that factors in factor_sizes are used in the rule.
    for factor in factor_sizes:
      if factor not in factors_inferrable:
        raise ValueError(
            f"Factor {factor} is not used in the rule, but size is provided")

    # Check that factors that are used for a whole dimension aren't in
    # factor_sizes and factors that are never used for a whole dimension are
    # in factor_sizes.
    for factor, inferrable in factors_inferrable.items():
      if factor not in factor_sizes and not inferrable:
        raise ValueError(
            f"Factor {factor} is only used in compound factors; must specify"
            " its size")
      if factor in factor_sizes and inferrable:
        raise ValueError(
            f"Factor {factor} represents a whole dimension; do not specify its"
            " size")

    self.factor_sizes = factor_sizes

  def __str__(self):
    return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"

  def build(
      self,
      operand_types: list[ir.Type],
      result_types: list[ir.Type],) -> ir.Attribute:
    """Builds the MLIR representation for the sharding rule.

    This is done by verifying that the rule is consistent with the types of
    the operation and converting the Einsum notation like string to
    OpShardingRuleAttr.
    """
    if len(self.operands) != len(operand_types):
      raise ValueError(
          f"Sharding rule has {len(self.operands)} operands, but the operation"
          f" has {len(operand_types)} operands"
      )
    if len(self.results) != len(result_types):
      raise ValueError(
          f"Sharding rule has {len(self.results)} results, but the operation"
          f" has {len(result_types)} results"
      )

    factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
    types = operand_types + result_types
    UNKNOWN = -1  # Representation for unknown factor size or factor index.

    def get_message_for_value(i):
      if i >= len(operand_types):
        return f"{i - len(operand_types)}th result"
      else:
        return f"{i}th operand"

    def get_rank_for_value(i):
      return ir.ShapedType(types[i]).rank

    def get_size_for_value_dim(i, j):
      return ir.ShapedType(types[i]).shape[j]

    def add_factor(factor, size):
      """Adds a factor to factors_to_indices_sizes.

      `size` may be a dimension size, a user specified factor size, or UNKNOWN
      if a factor is first used in a compound factor and then used for a
      whole dimension.
      """
      factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
      if factor_index != UNKNOWN:
        # Not the first time seeing the factor.
        if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
          factor_or_batching_dim = (
              f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
              else f"Batching dimension {factor[1:]}")
          raise ValueError(
              f"{factor_or_batching_dim} corresponds to two sizes:"
              f" {factor_size} and {size}")
        if size != UNKNOWN and factor_size == UNKNOWN:
          factors_to_indices_sizes[factor] = [factor_index, size]
      else:
        # First time seeing the factor.
        factor_index = len(factors_to_indices_sizes)
        factors_to_indices_sizes[factor] = [factor_index, size]

    def add_batching_dim_factor(batch_dim_order, factor_size):
      ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
      add_factor(ellipsis_batch_dim_name, factor_size)

    def build_dim_mapping_for_compound_factors(i, j, factors):
      accumulated_size = 1
      all_indices = []
      for factor in factors:
        factor_index, factor_size = factors_to_indices_sizes[factor]
        accumulated_size *= factor_size
        all_indices.append(factor_index)

      dim_size = get_size_for_value_dim(i, j)
      if accumulated_size != dim_size:
        raise ValueError(
            f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
            f" the size {accumulated_size} derived from the compound factors"
            f" {factors}")

      return sdy.DimMappingAttr.get(factor_indices=all_indices)

    # Add factors and their sizes in the order they appear in the rule,
    # including the batching dimensions represented by ellipsis.
    ellipsis_rank = None
    for i, value in enumerate(self.operands + self.results):
      if value and value[0] == _ELLIPSIS:
        has_ellipsis = True
        value = value[1:]
      else:
        has_ellipsis = False
      rule_rank = len(value)
      op_rank = get_rank_for_value(i)
      # The number of dimensions represented by ellipsis.
      current_ellipsis_rank = 0
      if has_ellipsis and op_rank > rule_rank:
        current_ellipsis_rank = op_rank - rule_rank
      if has_ellipsis:
        if ellipsis_rank is None:
          ellipsis_rank = current_ellipsis_rank
        elif ellipsis_rank != current_ellipsis_rank:
          raise ValueError(
              "Ellipsis represents different number of leading dimensions"
              f" {ellipsis_rank} and {current_ellipsis_rank}")
      rule_rank += current_ellipsis_rank
      if rule_rank != op_rank:
        msg = get_message_for_value(i)
        raise ValueError(
            f"Sharding rule {msg} has rank {rule_rank}, but the operation"
            f" {msg} has rank {op_rank}")

      for j in range(current_ellipsis_rank):
        add_batching_dim_factor(j, get_size_for_value_dim(i, j))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          add_factor(
              dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
        else:
          for factor in dim:
            add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))

    # Build the tensor mappings for each operand and result.
    tensor_mappings = []
    for i, value in enumerate(self.operands + self.results):
      dim_mappings = []

      if value and value[0] == _ELLIPSIS:
        value = value[1:]
        if ellipsis_rank is None:
          current_ellipsis_rank = 0
        else:
          current_ellipsis_rank = ellipsis_rank
      else:
        current_ellipsis_rank = 0

      for j in range(current_ellipsis_rank):
        dim_mappings.append(
            sdy.DimMappingAttr.get(factor_indices=[
                factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          dim_mappings.append(
              sdy.DimMappingAttr.get(
                  factor_indices=[factors_to_indices_sizes[dim][0]]))
        else:
          dim_mappings.append(
              build_dim_mapping_for_compound_factors(
                  i, j + current_ellipsis_rank, dim))

      tensor_mappings.append(
          sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))

    op_sharding_rule = sdy.OpShardingRuleAttr.get(
        factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
        operand_mappings=tensor_mappings[0:len(operand_types)],
        result_mappings=tensor_mappings[len(operand_types):])
    return op_sharding_rule
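
For completeness, a hypothetical sketch of how build() might be invoked during lowering. The helper name, the op handling, and the attribute key are assumptions for illustration; they are not part of this commit's diff.

# Hypothetical usage sketch, not from this commit: given an MLIR operation
# in a context where the Shardy dialect is available, check the rule against
# the op's operand/result types and attach the resulting OpShardingRuleAttr.
def _attach_sharding_rule(op, rule: SdyShardingRule):  # hypothetical helper
  attr = rule.build(
      [operand.type for operand in op.operands],
      [result.type for result in op.results])
  # The attribute key below is an assumption for illustration.
  op.attributes["sdy.sharding_rule"] = attr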
