|
18 | 18 |
|
19 | 19 | import math |
20 | 20 | import operator |
| 21 | +import re |
21 | 22 | from abc import ABC, abstractmethod |
22 | 23 | from collections import defaultdict |
23 | 24 | from enum import IntEnum |
24 | 25 | from functools import partial |
25 | | -from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set |
| 26 | +from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Set |
26 | 27 |
|
27 | 28 | import torch |
28 | 29 | import torch.nn as nn |
|
32 | 33 | from ...utils.logger import ad_logger |
33 | 34 | from ...utils.node_utils import ( |
34 | 35 | extract_param_names_from_lin_node, |
| 36 | + filtered_nodes, |
35 | 37 | identify_regions_between_residuals, |
36 | 38 | is_linear_op, |
37 | 39 | is_op, |
@@ -248,10 +250,200 @@ def apply(self, gm: GraphModule, node: Node) -> None: |
248 | 250 | class ShardingConfig(BaseModel): |
249 | 251 | """Configuration for sharding the model.""" |
250 | 252 |
|
| 253 | + rank: int = 0 |
| 254 | + world_size: int = 1 |
|  255 | +    predefined_config: Optional[Dict[str, Any]] = None |
251 | 256 | tp_transforms: List[TPShardingInfo] = Field(default_factory=list) |
252 | 257 | bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) |
253 | 258 | ep_transforms: List[EPShardingInfo] = Field(default_factory=list) |
254 | 259 |
|
|  260 | +    def __init__( |
|  261 | +        self, rank: int, world_size: int, sharding_config: Optional[Dict[str, Any]] = None |
|  262 | +    ): |
|  263 | +        # forward to the pydantic constructor so the fields are set and validated in one place |
|  264 | +        super().__init__(rank=rank, world_size=world_size, predefined_config=sharding_config) |
| 265 | + |
| 266 | + def create_sharding_from_config( |
|  267 | +        self, gm: GraphModule, sharding_config: Optional[Dict[str, Any]] = None |
| 268 | + ) -> None: |
|  269 | +        """Create sharding transformations from the predefined config. |
|  270 | + |
|  271 | +        TODO: currently, this applies only to TP sharding. |
|  272 | +        Args: |
|  273 | +            gm: Graph module to apply the transformations to. |
|  274 | +            sharding_config: optional override for self.predefined_config. |
|  275 | +        """ |
| 276 | + if sharding_config is not None: |
| 277 | + self.predefined_config = sharding_config |
| 278 | + |
|  279 | +        # check if the config is valid: |
|  280 | +        # 1. it contains "head_dim" and a "tp_plan" mapping of type Dict[str, str] |
|  281 | +        # 2. the tp_plan keys are module paths ("module.submodule.subsubmodule...") |
|  282 | +        #    in which the wildcard "*" is allowed |
|  283 | +        # 3. the allowed values are: |
|  284 | +        #    - "colwise" |
|  285 | +        #    - "rowwise" |
|  286 | +        #    - "sequence_parallel" |
|  287 | +        #    - "local_colwise" |
|  288 | +        #    - "local_rowwise" |
|  289 | +        #    - "local_packed_rowwise" |
|  290 | +        #    - "local" |
|  291 | +        #    - "gather" |
|  292 | +        # These constraints are based on https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249 |
| 293 | + |
| 294 | + if not isinstance(self.predefined_config, dict): |
| 295 | + ad_logger.warning("Sharding config is not a dictionary. Skipping.") |
| 296 | + return |
| 297 | + |
| 298 | + if "head_dim" not in self.predefined_config: |
| 299 | + ad_logger.warning("Sharding config does not contain head_dim. Skipping.") |
| 300 | + return |
| 301 | + head_dim = self.predefined_config["head_dim"] |
| 302 | + |
| 303 | + if "tp_plan" not in self.predefined_config: |
| 304 | + ad_logger.warning("Sharding config does not contain tp_plan. Skipping.") |
| 305 | + return |
| 306 | + tp_plan = self.predefined_config["tp_plan"] |
| 307 | + |
| 308 | + values = set(tp_plan.values()) |
| 309 | + allowed_values = { |
| 310 | + "colwise", |
| 311 | + "rowwise", |
| 312 | + "sequence_parallel", |
| 313 | + "local_colwise", |
| 314 | + "local_rowwise", |
| 315 | + "local_packed_rowwise", |
| 316 | + "local", |
| 317 | + "gather", |
| 318 | + } |
| 319 | + if not values.issubset(allowed_values): |
| 320 | + ad_logger.warning("Sharding config contains invalid values. Skipping.") |
| 321 | + return |
| 322 | + |
| 323 | + for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op): |
| 324 | + module_name = list(lin_node.meta["nn_module_stack"].keys())[-1] |
|  325 | +            # use regex to check whether module_name matches any key in tp_plan |
|  326 | +            for key in tp_plan.keys(): |
|  327 | +                # convert the wildcard pattern to a regex: wrap the key in "*" so the |
|  328 | +                # match is unanchored, escape the literal parts (e.g. dots), and turn |
|  329 | +                # each "*" into ".*" |
|  330 | +                pattern_string = "*" + key + "*" |
|  331 | +                parts = pattern_string.split("*") |
|  332 | +                pattern_regex = ".*".join(re.escape(part) for part in parts) |
|  333 | +                if re.match(pattern_regex, module_name): |
| 334 | + # we have a match. Get the config for this layer |
| 335 | + config = tp_plan[key] |
|  336 | +                    # TODO: @lucaslie: this is confusing! The HF config uses "colwise" and |
|  337 | +                    # "rowwise" as if Y = X @ W, i.e., all-gather after a column split and |
|  338 | +                    # all-reduce after a row split. Since we assume Y = W @ X^T, the column |
|  339 | +                    # and row splits are swapped here. |
| 340 | + if config == "colwise": |
| 341 | + # if we are doing colwise split, we need to check if we are in |
| 342 | + # attention module. If so, we need to set min_local_shape to the |
| 343 | + # head_dim - otherwise, we would risk splitting the heads into smaller shards. |
| 344 | + # TODO: is there a better way to check if we are in attention module? |
| 345 | + attn_names = ["attention", "Attention", "attn", "Attn"] |
| 346 | + if any(attn_name in module_name for attn_name in attn_names): |
| 347 | + min_local_shape = head_dim |
| 348 | + else: |
| 349 | + min_local_shape = 1 |
| 350 | + self.tp_transforms.append( |
| 351 | + TPShardingInfo( |
| 352 | + target_node=lin_node.name, |
| 353 | + split_dim=SplitDimension.ROW, |
| 354 | + rank=self.rank, |
| 355 | + world_size=self.world_size, |
| 356 | + dist_op=None, |
| 357 | + min_local_shape=min_local_shape, |
| 358 | + ) |
| 359 | + ) |
| 360 | + elif config == "rowwise": |
| 361 | + self.tp_transforms.append( |
| 362 | + TPShardingInfo( |
| 363 | + target_node=lin_node.name, |
| 364 | + split_dim=SplitDimension.COLUMN, |
| 365 | + rank=self.rank, |
| 366 | + world_size=self.world_size, |
| 367 | + dist_op="all_reduce", |
| 368 | + min_local_shape=1, |
| 369 | + ) |
| 370 | + ) |
| 371 | + elif "sequence" in config: |
| 372 | + # TODO: Sequence parallelism is not supported yet. |
| 373 | + ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") |
| 374 | + elif "local" in config: |
| 375 | + # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. |
| 376 | + ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.") |
| 377 | + elif "gather" in config: |
| 378 | + # Simple shard (row + all_gather) |
| 379 | + self.tp_transforms.append( |
| 380 | + TPShardingInfo( |
| 381 | + target_node=lin_node.name, |
| 382 | + split_dim=SplitDimension.ROW, |
| 383 | + rank=self.rank, |
| 384 | + world_size=self.world_size, |
| 385 | + dist_op="all_gather", |
| 386 | + min_local_shape=1, |
| 387 | + ) |
| 388 | + ) |
| 389 | + else: |
| 390 | + ad_logger.warning("Invalid sharding config. Skipping.") |
| 391 | + # after successful match, break the loop |
| 392 | + break |
| 393 | + |
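For reference, a minimal sketch of the kind of predefined config that `create_sharding_from_config` consumes. The module paths and dimensions below are made up for illustration; the real values come from the model's Hugging Face config (`head_dim`, `num_hidden_layers`, and its `tp_plan`).

```python
# Hypothetical example config -- keys mirror what the code above expects:
# "head_dim", "tp_plan", and "num_hidden_layers" (used by the last-n helper below).
example_sharding_config = {
    "head_dim": 128,
    "num_hidden_layers": 32,
    "tp_plan": {
        # wildcard keys are matched against the module path of each linear node
        "model.layers.*.self_attn.q_proj": "colwise",
        "model.layers.*.self_attn.o_proj": "rowwise",
        "model.layers.*.mlp.gate_proj": "colwise",
        "model.layers.*.mlp.down_proj": "rowwise",
    },
}
# A linear node under e.g. "model.layers.3.self_attn.q_proj" matches the first
# wildcard key: it gets a ROW split (HF "colwise") with min_local_shape=head_dim,
# while "o_proj" gets a COLUMN split followed by an all_reduce.
```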
| 394 | + def simple_shard_first_n_layers(self, n_layers: int) -> None: |
|  395 | +        """Simple shard (row split + all_gather) the first n_layers layers. |
|  396 | + |
|  397 | +        Take the existing config self.predefined_config, find the tp_plan entries |
|  398 | +        containing the wildcard "*", and prepend copies of those entries for layer |
|  399 | +        indices 0, 1, ..., n_layers - 1 with the value replaced by "gather". The |
|  400 | +        wildcard entries are kept afterwards as the fallback for all other layers. |
|  401 | +        """ |
|  402 | +        new_tp_plan = {} |
|  403 | +        for layer_pattern, config in self.predefined_config["tp_plan"].items(): |
|  404 | +            if "*" in layer_pattern: |
|  405 | +                # prepend explicit per-layer "gather" entries so that they match |
|  406 | +                # before the wildcard pattern in create_sharding_from_config |
|  407 | +                for i in range(n_layers): |
|  408 | +                    new_tp_plan[layer_pattern.replace("*", str(i))] = "gather" |
|  409 | + |
|  410 | +            # keep the original entry (wildcard or not) as the fallback |
|  411 | +            new_tp_plan[layer_pattern] = config |
|  412 | + |
|  413 | +        self.predefined_config["tp_plan"] = new_tp_plan |
| 414 | + |
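A quick sketch of what `simple_shard_first_n_layers` does to a (hypothetical) tp_plan entry; the pattern and sizes below are illustrative only.

```python
cfg = ShardingConfig(
    rank=0,
    world_size=2,
    sharding_config={"head_dim": 64, "tp_plan": {"model.layers.*.mlp.up_proj": "colwise"}},
)
cfg.simple_shard_first_n_layers(2)
# The tp_plan now lists explicit per-layer "gather" entries first, so they win over
# the wildcard fallback when create_sharding_from_config scans the keys in order:
# {
#     "model.layers.0.mlp.up_proj": "gather",
#     "model.layers.1.mlp.up_proj": "gather",
#     "model.layers.*.mlp.up_proj": "colwise",
# }
```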
| 415 | + def simple_shard_last_n_layers(self, n_layers: int) -> None: |
|  416 | +        """Simple shard (row split + all_gather) the last n_layers layers. |
|  417 | + |
|  418 | +        Take the existing config self.predefined_config, find the tp_plan entries |
|  419 | +        containing the wildcard "*", and prepend copies of those entries for the |
|  420 | +        last n_layers layer indices with the value replaced by "gather". The |
|  421 | +        wildcard entries are kept afterwards as the fallback for all other layers. |
|  422 | +        """ |
|  423 | +        new_tp_plan = {} |
|  424 | +        num_layers = self.predefined_config["num_hidden_layers"] |
|  425 | +        for layer_pattern, config in self.predefined_config["tp_plan"].items(): |
|  426 | +            if "*" in layer_pattern: |
|  427 | +                # prepend explicit per-layer "gather" entries so that they match |
|  428 | +                # before the wildcard pattern in create_sharding_from_config |
|  429 | +                for i in range(num_layers - n_layers, num_layers): |
|  430 | +                    new_tp_plan[layer_pattern.replace("*", str(i))] = "gather" |
|  431 | + |
|  432 | +            # keep the original entry (wildcard or not) as the fallback |
|  433 | +            new_tp_plan[layer_pattern] = config |
|  434 | +        self.predefined_config["tp_plan"] = new_tp_plan |
| 435 | + |
| 436 | + def simple_shard_attention_layers(self) -> None: |
|  437 | +        """Simple shard all attention layers. |
|  438 | +        If a tp_plan key refers to an attention module, replace its value with "gather". |
|  439 | +        """ |
|  440 | +        for layer_pattern in self.predefined_config["tp_plan"]: |
|  441 | +            if any( |
|  442 | +                attn_name in layer_pattern |
|  443 | +                for attn_name in ["attention", "Attention", "attn", "Attn"] |
|  444 | +            ): |
|  445 | +                self.predefined_config["tp_plan"][layer_pattern] = "gather" |
| 446 | + |
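A hypothetical end-to-end usage sketch; `rank`, `world_size`, `gm` (the exported GraphModule), and `model_config` (an HF-style dict with `head_dim` and `tp_plan` as above) stand in for the caller's objects.

```python
sharding_config = ShardingConfig(rank=rank, world_size=world_size, sharding_config=model_config)
sharding_config.create_sharding_from_config(gm)   # populate tp_transforms from the tp_plan
sharding_transform_executor(gm, sharding_config)  # apply the collected transforms to the graph
```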
255 | 447 |
|
256 | 448 | def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None: |
257 | 449 | """Apply transformations to the graph module. |
|