# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements SdyShardingRule."""

from collections import OrderedDict

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy


_CompoundFactor = tuple[str, ...]
_DimMapping = tuple[str | _CompoundFactor, ...]

# A single-character replacement for ... to simplify parsing.
_ELLIPSIS: str = "…"

# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors.
_BATCHING_DIM_FACTOR_PREFIX = "?"

def _get_batching_dim_factor_name(batch_dim_order: int) -> str:
  """Constructs a factor name for a batching dimension.

  We expand the leading ... into factors representing the batching dimensions
  to support building the MLIR representation for the sharding rule. For this
  reason, we construct a factor name that won't be used by users for the
  batching dimensions.
  """
  return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
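
# For illustration: batching-dimension factor names use the reserved "?"
# prefix, so they can never collide with user-written factors, which must
# start with a letter. E.g. _get_batching_dim_factor_name(0) == "?0" and
# _get_batching_dim_factor_name(1) == "?1".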

def _parse_values(
    rule: str,
) -> tuple[_DimMapping, ...]:
  """Parses the LHS or RHS of an Einsum-notation-like string.

  Converts each operand or result in the Einsum-notation-like string to a
  tuple of _DimMapping. This very closely follows how einops parses their
  rules in einops/parsing.py.

  Args:
    rule: The Einsum notation for the operands or results of an operation.

  Returns:
    The tuple of values.

  Raises:
    ValueError: If the rule is not balanced or contains unknown characters.
  """

  # Remove unnecessary spaces in the rule to simplify the parsing process.
  words = rule.split()
  rule = " ".join(words)

  # Similar to einops rules, an empty LHS/RHS has a single scalar value.
  if not rule:
    return ((),)

  all_values = []
  # Represents all dimensions of a value. When value[0] == _ELLIPSIS, the
  # value may have 0 or more leading dimensions.
  value = []
  current_factor = None
  # A value of None indicates that the current dimension is not a compound
  # dimension, while a value of [] indicates that we have just started parsing
  # a compound dimension.
  current_compound_dim: list[str] | None = None

  def add_factor(x):
    if current_compound_dim is None:
      value.append(x)
    else:
      current_compound_dim.append(x)

  for char in rule:
    if char == _ELLIPSIS:
      if (current_factor is not None or current_compound_dim is not None
          or value):
        raise ValueError(
            "Ellipsis can only be used at the beginning of a dimension")
      add_factor(_ELLIPSIS)
      continue
    if char in "(), ":
      if current_factor is not None:
        add_factor(current_factor)
        current_factor = None
      if char == "(":
        if current_compound_dim is not None:
          raise ValueError(
              "Compound factors should be one level; nested brackets are not"
              " allowed")
        current_compound_dim = []
      elif char == ")":
        if current_compound_dim is None:
          raise ValueError("Brackets are not balanced")
        if len(current_compound_dim) <= 1:
          raise ValueError("Brackets should contain at least two factors")
        value.append(tuple(current_compound_dim))
        current_compound_dim = None
      elif char == ",":
        all_values.append(tuple(value))
        value = []
    elif char == "_" or char.isdigit() or char.isalpha():
      if current_factor is None:
        if char.isdigit():
          raise ValueError(
              f"Factor names have to start with a letter, but got '{char}'")
        current_factor = char
      else:
        current_factor += char
    else:
      raise ValueError(f"Unknown character '{char}'")

  if current_compound_dim is not None:
    raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
  if current_factor is not None:
    add_factor(current_factor)
  all_values.append(tuple(value))

  return tuple(all_values)
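
# For illustration: the parser maps an Einsum-notation-like string to nested
# tuples. Assuming "..." has already been replaced by the internal _ELLIPSIS
# character, we would have:
#   _parse_values("i j, j k") == (("i", "j"), ("j", "k"))
#   _parse_values("(i j) k") == ((("i", "j"), "k"),)
#   _parse_values("") == ((),)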


class SdyShardingRule:
  """A representation for a Shardy sharding rule.

  A SdyShardingRule includes an Einsum-notation-like string and an optional
  list of factor sizes. A factor is a name in the Einsum notation. If a factor
  is only used in compound factors, its size must be specified.

  SdyShardingRule examples:

  * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
  * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
  * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
  * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
  * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
  """

  def __init__(self, rule: str, **factor_sizes):
    """Constructs a SdyShardingRule object from an Einsum-notation-like string.

    This is done by verifying that the input Einsum-notation-like string with
    optional factor sizes represents a valid sharding rule, and converting it
    to an internal representation.

    Args:
      rule: The Einsum-notation-like string for an operation.
      **factor_sizes: The optional factor sizes.

    Raises:
      ValueError: If there is any problem with the rule or factor_sizes.
    """
    if not isinstance(rule, str):
      raise TypeError(f"rule must be a str, but got {type(rule)}")
    if not all(isinstance(size, int) for size in factor_sizes.values()):
      raise TypeError(
          f"factor_sizes must be a dict of str to int, but got {factor_sizes}")

    # Replace ... with a single char to simplify parsing.
    if _ELLIPSIS in rule:
      raise ValueError(f"Unknown character '{_ELLIPSIS}'")
    if "." in rule:
      rule = rule.replace("...", _ELLIPSIS)
      if "." in rule:
        raise ValueError("Character '.' must be used inside ellipsis '...'")

    try:
      operands, results = rule.split("->")
    except ValueError as e:
      raise ValueError(f"There is no -> in rule: '{rule}'") from e

    self.operands = _parse_values(operands)
    self.results = _parse_values(results)

    # Find all factors and mark whether their size can be inferred.
    factors_inferrable = dict()
    for value in self.operands + self.results:
      for dim in value:
        if dim == _ELLIPSIS:
          continue
        if isinstance(dim, str):
          factors_inferrable[dim] = True
        else:
          for factor in dim:
            if factor not in factors_inferrable:
              factors_inferrable[factor] = False

    # Check that factors in factor_sizes are used in the rule.
    for factor in factor_sizes:
      if factor not in factors_inferrable:
        raise ValueError(
            f"Factor {factor} is not used in the rule, but size is provided")

    # Check that factors that are used for a whole dimension aren't in
    # factor_sizes and factors that are never used for a whole dimension are
    # in factor_sizes.
    for factor, inferrable in factors_inferrable.items():
      if factor not in factor_sizes and not inferrable:
        raise ValueError(
            f"Factor {factor} is only used in compound factors; must specify"
            " its size")
      if factor in factor_sizes and inferrable:
        raise ValueError(
            f"Factor {factor} represents a whole dimension; do not specify its"
            " size")

    self.factor_sizes = factor_sizes
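
  # For illustration, the size validation above implies (hypothetical calls,
  # not executed here):
  #   SdyShardingRule("(i j) -> (j i)", i=4, j=2)  # OK: i and j appear only
  #                                                # in compound factors.
  #   SdyShardingRule("(i j) -> i j")   # OK: sizes inferred from whole dims.
  #   SdyShardingRule("(i j) -> (j i)") # ValueError: sizes cannot be inferred.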

  def __str__(self):
    return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"

  def build(
      self,
      operand_types: list[ir.Type],
      result_types: list[ir.Type],
  ) -> ir.Attribute:
    """Builds the MLIR representation for the sharding rule.

    This is done by verifying that the rule is consistent with the types of
    the operation and converting the Einsum-notation-like string to
    OpShardingRuleAttr.
    """
    if len(self.operands) != len(operand_types):
      raise ValueError(
          f"Sharding rule has {len(self.operands)} operands, but the operation"
          f" has {len(operand_types)} operands"
      )
    if len(self.results) != len(result_types):
      raise ValueError(
          f"Sharding rule has {len(self.results)} results, but the operation"
          f" has {len(result_types)} results"
      )

    factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
    types = operand_types + result_types
    UNKNOWN = -1  # Representation for unknown factor size or factor index.

    def get_message_for_value(i):
      if i >= len(operand_types):
        return f"{i - len(operand_types)}th result"
      else:
        return f"{i}th operand"

    def get_rank_for_value(i):
      return ir.ShapedType(types[i]).rank

    def get_size_for_value_dim(i, j):
      return ir.ShapedType(types[i]).shape[j]

    def add_factor(factor, size):
      """Adds a factor to factors_to_indices_sizes.

      `size` may be a dimension size, a user-specified factor size, or UNKNOWN
      if a factor is first used in a compound factor and then used for a
      whole dimension.
      """
      factor_index, factor_size = factors_to_indices_sizes.get(
          factor, [UNKNOWN, UNKNOWN])
      if factor_index != UNKNOWN:
        # Not the first time seeing the factor.
        if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
          factor_or_batching_dim = (
              f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
              else f"Batching dimension {factor[1:]}")
          raise ValueError(
              f"{factor_or_batching_dim} corresponds to two sizes:"
              f" {factor_size} and {size}")
        if size != UNKNOWN and factor_size == UNKNOWN:
          factors_to_indices_sizes[factor] = [factor_index, size]
      else:
        # First time seeing the factor.
        factor_index = len(factors_to_indices_sizes)
        factors_to_indices_sizes[factor] = [factor_index, size]

    def add_batching_dim_factor(batch_dim_order, factor_size):
      ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
      add_factor(ellipsis_batch_dim_name, factor_size)

    def build_dim_mapping_for_compound_factors(i, j, factors):
      accumulated_size = 1
      all_indices = []
      for factor in factors:
        factor_index, factor_size = factors_to_indices_sizes[factor]
        accumulated_size *= factor_size
        all_indices.append(factor_index)

      dim_size = get_size_for_value_dim(i, j)
      if accumulated_size != dim_size:
        raise ValueError(
            f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
            f" the size {accumulated_size} derived from the compound factors"
            f" {factors}")

      return sdy.DimMappingAttr.get(factor_indices=all_indices)
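
    # For illustration: with the rule "(i j) -> i j" applied to a reshape
    # (8,) -> (4, 2), sizes i=4 and j=2 are taken from the whole result
    # dimensions, and the compound operand dimension must satisfy i * j == 8.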

    # Add factors and their sizes in the order they appear in the rule,
    # including the batching dimensions represented by ellipsis.
    ellipsis_rank = None
    for i, value in enumerate(self.operands + self.results):
      if value and value[0] == _ELLIPSIS:
        has_ellipsis = True
        value = value[1:]
      else:
        has_ellipsis = False
      rule_rank = len(value)
      op_rank = get_rank_for_value(i)
      # The number of dimensions represented by the ellipsis.
      current_ellipsis_rank = 0
      if has_ellipsis and op_rank > rule_rank:
        current_ellipsis_rank = op_rank - rule_rank
      if has_ellipsis:
        if ellipsis_rank is None:
          ellipsis_rank = current_ellipsis_rank
        elif ellipsis_rank != current_ellipsis_rank:
          raise ValueError(
              "Ellipsis represents a different number of leading dimensions"
              f" {ellipsis_rank} and {current_ellipsis_rank}")
      rule_rank += current_ellipsis_rank
      if rule_rank != op_rank:
        msg = get_message_for_value(i)
        raise ValueError(
            f"Sharding rule {msg} has rank {rule_rank}, but the operation"
            f" {msg} has rank {op_rank}")

      for j in range(current_ellipsis_rank):
        add_batching_dim_factor(j, get_size_for_value_dim(i, j))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          add_factor(
              dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
        else:
          for factor in dim:
            add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))
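
    # For illustration: for the rule "... i j, ... j k -> ... i k" on rank-3
    # tensors, factors are assigned indices in order of appearance: "?0" (the
    # expanded batching dimension), then "i", "j", and "k".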

    # Build the tensor mappings for each operand and result.
    tensor_mappings = []
    for i, value in enumerate(self.operands + self.results):
      dim_mappings = []

      if value and value[0] == _ELLIPSIS:
        value = value[1:]
        if ellipsis_rank is None:
          current_ellipsis_rank = 0
        else:
          current_ellipsis_rank = ellipsis_rank
      else:
        current_ellipsis_rank = 0

      for j in range(current_ellipsis_rank):
        dim_mappings.append(
            sdy.DimMappingAttr.get(factor_indices=[
                factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          dim_mappings.append(
              sdy.DimMappingAttr.get(
                  factor_indices=[factors_to_indices_sizes[dim][0]]))
        else:
          dim_mappings.append(
              build_dim_mapping_for_compound_factors(
                  i, j + current_ellipsis_rank, dim))

      tensor_mappings.append(
          sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))

    op_sharding_rule = sdy.OpShardingRuleAttr.get(
        factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
        operand_mappings=tensor_mappings[0:len(operand_types)],
        result_mappings=tensor_mappings[len(operand_types):])
    return op_sharding_rule
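
# A hedged sketch of how a rule might be used during lowering. The op, the
# attribute name, and the types below are illustrative; the actual wiring
# depends on the caller that attaches the OpShardingRuleAttr to a custom op:
#
#   rule = SdyShardingRule("... i j, ... j k -> ... i k")
#   attr = rule.build(
#       [op.operands[0].type, op.operands[1].type],
#       [op.results[0].type])
#   op.attributes["sdy.sharding_rule"] = attr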