diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 61337ae3f42..791436d050f 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -157,6 +157,12 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): "If False, auto-detect and use column+row (all_reduce) sharding when possible.", ) + use_sharding_from_factory: bool = Field( + default=False, + description="If True, use sharding from the model config (if present). " + "If False, run heuristics to detect sharding.", + ) + compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( Field( default="torch-compile", diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 42a30402537..b97f54a2f68 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -2,6 +2,7 @@ import copy from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Callable, Dict, Optional, Type import torch @@ -12,6 +13,13 @@ from ..utils.logger import ad_logger +class ShardingConfigSource(Enum): + """Enum for factory source.""" + + HUGGINGFACE = "huggingface" + UNKNOWN = "unknown" + + class ModelFactory(ABC): """An interface to return and correctly initialize a model from a desired source. @@ -38,6 +46,7 @@ def __init__( self.max_seq_len = max_seq_len self._prefetched_model_path: Optional[str] = None self._prefetched_tokenizer_path: Optional[str] = None + self._sharding_config: Dict[str, Any] = {} @property def model(self) -> Optional[str]: @@ -96,6 +105,10 @@ def get_quant_config(self) -> Dict: """Returns the quantization config for this model or None if not quantized.""" return {} + def get_sharding_config(self) -> Dict: + """Returns the sharding config for this model.""" + return self._sharding_config + def get_cache_config(self) -> CacheConfig: """Return the cache configuration for the model. @@ -104,6 +117,14 @@ def get_cache_config(self) -> CacheConfig: """ return CacheConfig() + def get_sharding_config_source(self) -> ShardingConfigSource: + """Return the source of the model factory. + + Returns: + The source identifier for this model factory. + """ + return ShardingConfigSource.UNKNOWN + def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer for the model. diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index fc37c1e557a..624d53172cb 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -30,7 +30,7 @@ from ..custom_ops.attention_interface import CacheConfig from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger -from .factory import ModelFactory, ModelFactoryRegistry +from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource @contextmanager @@ -174,12 +174,25 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module: if hasattr(model, "post_init"): model.post_init() + # if present, initialize sharding config. We need head_dim for colwise sharding. 
+ self._set_sharding_config(model.config) + # patch forward method model.forward = types.MethodType(self._simple_forward, model) model.eval() return model + def _set_sharding_config(self, model_config: PretrainedConfig): + """Set the sharding config for the model.""" + self._sharding_config["head_dim"] = 1 + if hasattr(model_config, "base_model_tp_plan"): + self._sharding_config["tp_plan"] = model_config.base_model_tp_plan + if hasattr(model_config, "head_dim"): + self._sharding_config["head_dim"] = model_config.head_dim + if hasattr(model_config, "num_hidden_layers"): + self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers + def get_quant_config(self) -> Dict: return self._quant_config or {} @@ -196,6 +209,14 @@ def get_cache_config(self): kv_cache_dtype = None return CacheConfig(dtype=kv_cache_dtype) + def get_sharding_config_source(self) -> ShardingConfigSource: + """Return the source of the model factory. + + Returns: + The source identifier for this model factory. + """ + return ShardingConfigSource.HUGGINGFACE + def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer—either a custom name or the model's default.""" if self.tokenizer is None: @@ -363,6 +384,19 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]: }, } + def _set_sharding_config(self, model_config: PretrainedConfig): + """Override the sharding config for the model with text_config.""" + super()._set_sharding_config(model_config) + + if hasattr(model_config, "text_config"): + text_config = model_config.text_config + if hasattr(text_config, "base_model_tp_plan"): + self._sharding_config["tp_plan"] = text_config.base_model_tp_plan + if hasattr(text_config, "head_dim"): + self._sharding_config["head_dim"] = text_config.head_dim + if hasattr(text_config, "num_hidden_layers"): + self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers + @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py index d7ed5918a49..549a244b97f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py @@ -18,20 +18,23 @@ import math import operator +import re from abc import ABC, abstractmethod from collections import defaultdict from enum import IntEnum from functools import partial -from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set +from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Set import torch import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field from torch.fx import GraphModule, Node +from ...models.factory import ModelFactory, ShardingConfigSource from ...utils.logger import ad_logger from ...utils.node_utils import ( extract_param_names_from_lin_node, + filtered_nodes, identify_regions_between_residuals, is_linear_op, is_op, @@ -44,8 +47,12 @@ class SplitDimension(IntEnum): """Enum for tensor split dimensions in sharding.""" - ROW = 0 # Split along rows (first dimension) - COLUMN = 1 # Split along columns (second dimension) + # NOTE: The names COLUMN/ROW follow the Hugging Face base_model_tp_plan + # notation ("colwise"/"rowwise"). Since the linear op computes Y = X @ W^T + # with W stored as [out_features, in_features], a "colwise" split of the + # output features is a split over dimension 0 of the stored weight. + COLUMN = 0 # "colwise": split the stored weight along dim 0 (output features) + ROW = 1 # "rowwise": split the stored weight along dim 1 (input features)
class ShardingTransformInfo(BaseModel, ABC): @@ -90,16 +97,16 @@ class TPShardingInfo(ShardingTransformInfo): def validate(self, gm: GraphModule = None, node: Node = None) -> bool: """Validate the transformation configuration.""" if self.dist_op is not None: - if self.split_dim == SplitDimension.ROW: + if self.split_dim == SplitDimension.COLUMN: if self.dist_op == "all_reduce": ad_logger.warning( - f"Row split is only supported for all_gather. Skipping {self}." + f"Column split is only supported for all_gather. Skipping {self}." ) return False - if self.split_dim == SplitDimension.COLUMN: + if self.split_dim == SplitDimension.ROW: if self.dist_op == "all_gather": ad_logger.warning( - f"Column split is only supported for all_reduce. Skipping {self}." + f"Row split is only supported for all_reduce. Skipping {self}." ) return False return True @@ -248,10 +255,260 @@ def apply(self, gm: GraphModule, node: Node) -> None: class ShardingConfig(BaseModel): """Configuration for sharding the model.""" + factory_source: ShardingConfigSource + rank: int + world_size: int + _predefined_config: Optional[Dict[str, Any]] = None + simple_shard_only: bool = False + use_sharding_from_factory: bool = False tp_transforms: List[TPShardingInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) + def __init__( + self, + rank: int, + world_size: int, + factory_source: ShardingConfigSource = ShardingConfigSource.UNKNOWN, + sharding_config: Dict[str, Any] = None, + simple_shard_only: bool = False, + use_sharding_from_factory: bool = False, + ): + super().__init__( + factory_source=factory_source, + rank=rank, + world_size=world_size, + _predefined_config=sharding_config, + simple_shard_only=simple_shard_only, + use_sharding_from_factory=use_sharding_from_factory, + ) + + # Pydantic does not allow setting private fields via __init__, so set it explicitly. + self._predefined_config = sharding_config + # Validate the config after initialization + if self._predefined_config is not None: + self.validate_config() + + def validate_config(self) -> bool: + if self.factory_source != ShardingConfigSource.HUGGINGFACE: + ad_logger.warning( + "Sharding config is currently only " + "supported for HuggingFace. Skipping." + ) + # invalidate the config + self._predefined_config = {} + return False + + if not isinstance(self._predefined_config, dict): + ad_logger.warning("Sharding config is not a dictionary. Skipping.") + # invalidate the config + self._predefined_config = {} + return False + + if "head_dim" not in self._predefined_config: + ad_logger.warning("Sharding config does not contain head_dim. Skipping.") + # invalidate the config + self._predefined_config = {} + return False + + if "tp_plan" not in self._predefined_config: + ad_logger.warning("Sharding config does not contain tp_plan. Skipping.") + # invalidate the config + self._predefined_config = {} + return False + tp_plan = self._predefined_config["tp_plan"] + + values = set(tp_plan.values()) + allowed_values = { + "colwise", # dim-0 split, no collective + "rowwise", # dim-1 split + all_reduce + "gather", # simple shard (dim-0 split + all_gather) + # TODO: remaining values are not supported yet. + # They require hybrid EP+TP and/or SP support.
+ # "sequence_parallel", # sequence parallelism + # "local_colwise", + # "local_rowwise", + # "local_packed_rowwise", + # "local", + } + if not values.issubset(allowed_values): + ad_logger.warning("Sharding config contains invalid values. Skipping.") + # invalidate the config + self._predefined_config = {} + return False + return True + + def get_predefined_config(self) -> Dict[str, Any]: + return self._predefined_config + + +def detect_sharding_from_factory_config( + gm: GraphModule, + sharding_config: ShardingConfig, +) -> None: + """ + Create sharding transformations from the predefined config. + TODO: currently, it applies only to TP sharding. + Args: + gm: Graph module to apply transformations to + sharding_config: Predefined sharding configuration + """ + # check if config is valid. + # 1. it is a Dict[str, str] + # 2. the keys are of format "module.submodule.subsubmodule..." + # 3. the wildcard "*" is allowed in the keys + # 4. the allowed values are: + # - "colwise" + # - "rowwise" + # - "sequence_parallel" + # - "local_colwise" + # - "local_rowwise" + # - "local" + # - "gather" + # The following constraints are based on + # https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249 + + factory_config = sharding_config.get_predefined_config() + head_dim = factory_config["head_dim"] + tp_plan = factory_config["tp_plan"] + + rank, world_size = sharding_config.rank, sharding_config.world_size + + # If the node is inside the attention module, we need to set min_local_shape to the + # head_dim - otherwise, we would risk splitting the heads into smaller shards. + # TODO: is there a better way to check if we are in attention module? + attn_names = [ + "attention", + "Attention", + "attn", + "Attn", + "q_proj", + "k_proj", + "v_proj", + "o_proj", + ] + + for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op): + # use node's weight name to get the module name + module_name = lin_node.args[1].target + + if any(attn_name in module_name for attn_name in attn_names): + min_local_shape = head_dim + else: + min_local_shape = 1 + + # use regex to find if module_name matches any of the keys in sharding_config + for key in tp_plan.keys(): + pattern_string = "*" + key + "*" + # convert it to regex. Escape dots, replace * with .* + # First, we substitute * with an unlikely character, e.g. @ + # Then we escape dots, and finally we replace @ with .* + pattern_string = pattern_string.replace("*", "@") + pattern_regex = re.escape(pattern_string).replace("@", ".*") + if re.match(pattern_regex, module_name): + # we have a match. Get the config for this layer + config = tp_plan[key] + if config == "colwise": + sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) + ) + elif config == "rowwise": + sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) + ) + elif "sequence" in config: + # TODO: Sequence parallelism is not supported yet. + ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") + elif "local" in config: + # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. + ad_logger.warning("Local EP+TP sharding is not supported yet. 
Skipping.") + elif "gather" in config: + # Simple shard (dim-0 split + all_gather) + sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) + ) + else: + ad_logger.warning("Invalid sharding config. Skipping.") + # after successful match, break the loop + break + + +def simple_shard_first_n_layers(sharding_config: ShardingConfig, n_layers: int) -> None: + """ + Force a simple shard (all_gather) on the first n layers. + 1. Take the existing predefined factory config, + 2. Search for tp_plan keys containing the wildcard "*", + 3. Prepend entries with "0", "1", ..., "n_layers-1" substituted for "*" and mapped + to "gather"; keep the original wildcard entry as the fallback. + """ + new_tp_plan = {} + factory_config = sharding_config.get_predefined_config() + for layer_pattern, config in factory_config["tp_plan"].items(): + if "*" in layer_pattern: + # Add explicit "gather" entries for the first n_layers before the wildcard entry + + for i in range(n_layers): + new_tp_plan[layer_pattern.replace("*", str(i))] = "gather" + + # Add the default config after + new_tp_plan[layer_pattern] = config + + sharding_config._predefined_config["tp_plan"] = new_tp_plan + + +def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -> None: + """ + Force a simple shard (all_gather) on the last n layers. + 1. Take the existing predefined factory config, + 2. Search for tp_plan keys containing the wildcard "*", + 3. Prepend entries with "num_layers - n_layers", ..., "num_layers - 1" substituted + for "*" and mapped to "gather"; keep the original wildcard entry as the fallback. + """ + new_tp_plan = {} + factory_config = sharding_config.get_predefined_config() + num_layers = factory_config["num_hidden_layers"] + for layer_pattern, config in factory_config["tp_plan"].items(): + if "*" in layer_pattern: + # Add explicit "gather" entries for the last n_layers before the wildcard entry + + for i in range(num_layers - n_layers, num_layers): + new_tp_plan[layer_pattern.replace("*", str(i))] = "gather" + + # Add the default config after + new_tp_plan[layer_pattern] = config + sharding_config._predefined_config["tp_plan"] = new_tp_plan + + +def simple_shard_attention_layers(sharding_config: ShardingConfig) -> None: + """ + If any tp_plan key refers to an attention module, set its value to "gather" (simple shard). + """ + for layer_pattern, config in sharding_config._predefined_config["tp_plan"].items(): + if any( + attn_name in layer_pattern for attn_name in ["attention", "Attention", "attn", "Attn"] + ): + sharding_config._predefined_config["tp_plan"][layer_pattern] = "gather" + def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None: """Apply transformations to the graph module.
@@ -460,7 +717,7 @@ def _append_simple_shard( tp_shards.append( TPShardingInfo( target_node=n.name, - split_dim=SplitDimension.ROW, + split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size, dist_op="all_gather", @@ -470,12 +727,47 @@ def _append_simple_shard( sharding_config.tp_transforms.extend(tp_shards) -def detect_column_row_shard( +def detect_sharding( gm: GraphModule, - rank: int, + factory: ModelFactory, + local_rank: int, world_size: int, + simple_shard_only: bool, + use_sharding_from_factory: bool, +) -> ShardingConfig: + sharding_config = ShardingConfig( + local_rank, + world_size, + factory.get_sharding_config_source(), + factory.get_sharding_config(), + simple_shard_only, + use_sharding_from_factory, + ) + + if ( + sharding_config.use_sharding_from_factory + and len(sharding_config.get_predefined_config()) > 0 + ): + ad_logger.info("Applying sharding from config") + detect_sharding_from_factory_config(gm, sharding_config) + return sharding_config + + ad_logger.info("Running autodeploy sharding heuristics") + # run TP sharding across ranks + detect_column_row_shard(gm, sharding_config) + + # run EP sharding across ranks + detect_ep_shard(gm, sharding_config) + + # run BMM sharding across ranks + detect_dp_bmm_shard(gm, sharding_config) + + return sharding_config + + +def detect_column_row_shard( + gm: GraphModule, sharding_config: ShardingConfig, - simple_shard_only: bool = False, ) -> None: """A transformation to apply sharding to the model following tensor parallelism. @@ -495,12 +787,15 @@ def detect_column_row_shard( """ ad_logger.debug("Before sharding graph: " + str(gm)) + rank, world_size = sharding_config.rank, sharding_config.world_size if world_size < 2: - ad_logger.info("Skipping sharding for single device") + ad_logger.info("Skipping TP sharding for single device") return assert isinstance(gm, GraphModule), "Expecting GraphModule" + ad_logger.info("Running TP sharding detection") + # find boundary nodes of regions we want to shard boundary_nodes = identify_regions_between_residuals(gm) @@ -571,7 +866,7 @@ def detect_column_row_shard( num_shards += 1 - if simple_shard_only: + if sharding_config.simple_shard_only: ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") _append_simple_shard(nodes_linear, rank, world_size, sharding_config) continue @@ -648,9 +943,7 @@ def detect_column_row_shard( ad_logger.info(f"Found {num_shards} TP shards") -def detect_dp_bmm_shard( - gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig -) -> None: +def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None: """A transformation to apply sharding to batched matrix multiplications in the graph. We'll shard the BMM nodes by slicing the batch dimension of input tensors into world_size number of slices. @@ -660,9 +953,9 @@ def detect_dp_bmm_shard( We'll also assume that the inputs to BMM are broadcasted across the devices already. 
""" ad_logger.debug("Before sharding graph: " + str(gm)) - + rank, world_size = sharding_config.rank, sharding_config.world_size if world_size < 2: - ad_logger.info("Skipping sharding for single device") + ad_logger.info("Skipping DP BMM sharding for single device") return assert isinstance(gm, GraphModule), "Expecting GraphModule" @@ -728,13 +1021,12 @@ def detect_dp_bmm_shard( ad_logger.info(f"Found {num_bmm_shards} BMM shards") -def detect_ep_shard( - gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig -) -> None: +def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None: ad_logger.debug("Before sharding graph: " + str(gm)) + rank, world_size = sharding_config.rank, sharding_config.world_size if world_size < 2: - ad_logger.info("Skipping sharding for single device") + ad_logger.info("Skipping EP sharding for single device") return assert isinstance(gm, GraphModule), "Expecting GraphModule" diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index 3844ce4d312..306730c6e20 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -15,10 +15,7 @@ from ..utils.logger import ad_logger from ._graph import canonicalize_graph, lift_to_meta, move_to_device from .library import ( - ShardingConfig, - detect_column_row_shard, - detect_dp_bmm_shard, - detect_ep_shard, + detect_sharding, eliminate_redundant_transposes, fuse_allreduce_residual_rmsnorm, fuse_collectives, @@ -95,20 +92,15 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module: # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528 optimize_rope(egm) - # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config. - sharding_config = ShardingConfig() - - # run TP sharding across ranks - detect_column_row_shard( - egm, local_rank, world_size, sharding_config, self.ad_config.simple_shard_only + sharding_config = detect_sharding( + egm, + self.factory, + local_rank, + world_size, + self.ad_config.simple_shard_only, + self.ad_config.use_sharding_from_factory, ) - # run EP sharding across ranks - detect_ep_shard(egm, local_rank, world_size, sharding_config) - - # run BMM sharding across ranks - detect_dp_bmm_shard(egm, local_rank, world_size, sharding_config) - sharding_transform_executor(egm, sharding_config) # let's run a shape propagation pass to update the graph with correct meta values for diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 48f06c70e60..0841fbb59b3 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -2,7 +2,7 @@ import operator from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import Callable, Iterable, List, Optional, Tuple, Union, overload import torch from torch._ops import OpOverload, OpOverloadPacket @@ -206,35 +206,63 @@ def is_op(node: Node, ops: Union[OperatorLike, Iterable[OperatorLike]]) -> bool: return is_match +@overload +def filtered_nodes(nodes: Iterable[Node], target: Callable[[Node], bool]) -> Iterable[Node]: + """Overload for filtering with a callable target function.""" + ... 
+ + +@overload def filtered_nodes( nodes: Iterable[Node], ops: Union[OperatorLike, Iterable[OperatorLike]] ) -> Iterable[Node]: - """Iterate over nodes that are filtered by the given operations. + """Overload for filtering with operation(s).""" + ... + + +def filtered_nodes( + nodes: Iterable[Node], + target: Union[Callable[[Node], bool], Union[OperatorLike, Iterable[OperatorLike]]] = None, + ops: Union[OperatorLike, Iterable[OperatorLike]] = None, +) -> Iterable[Node]: + """Iterate over nodes that are filtered by the given operations or target function. This utility function simplifies the common pattern of iterating through nodes - and filtering by operation type. + and filtering by operation type or custom function. Args: nodes: Iterable of nodes to filter (e.g., gm.graph.nodes) - ops: Operation(s) to match against + target: Either a callable function that takes a Node and returns bool, + or operation(s) to match against (deprecated, use ops parameter) + ops: Operation(s) to match against (preferred over target for operations) Yields: - Node: Nodes that match the given operations + Node: Nodes that match the given operations or target function Example: - # Instead of: - for node in gm.graph.nodes: - if not is_op(node, torch.ops.aten.linear): - continue + # Using callable function: + for node in filtered_nodes(gm.graph.nodes, is_linear_op): + # process node + + # Using operations: + for node in filtered_nodes(gm.graph.nodes, ops=torch.ops.aten.linear): # process node - # Use: - for node in filtered_nodes(gm.graph.nodes, torch.ops.aten.linear): + # Using multiple operations: + for node in filtered_nodes(gm.graph.nodes, ops=[torch.ops.aten.linear, torch.ops.aten.bmm]): # process node """ - for node in nodes: - if is_op(node, ops): - yield node + # Handle the case where target is a callable function + if callable(target) and not isinstance(target, (OpOverloadPacket, OpOverload)): + for node in nodes: + if target(node): + yield node + else: + # Handle the case where target or ops contains operations + operations = ops if ops is not None else target + for node in nodes: + if is_op(node, operations): + yield node def is_linear_op(node: Node, include_quantization: bool = False) -> bool: diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index ab135aa28a1..4372fee51c5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -70,8 +70,11 @@ def _get_expected_num_params(num_p_og: int) -> int: return num_params def transform_func(gm) -> None: - sharding_config = ShardingConfig() - detect_dp_bmm_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig( + rank, + world_size, + ) + detect_dp_bmm_shard(gm, sharding_config) sharding_transform_executor(gm, sharding_config) # now run the test @@ -118,8 +121,8 @@ def _run_pattern_detection_job( ) # get detected transformations - sharding_config = ShardingConfig() - detect_dp_bmm_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig(rank, world_size) + detect_dp_bmm_shard(gm, sharding_config) detected_transformations = sharding_config.bmm_transforms # Run pattern detection test diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py 
b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 19cce483297..37224a7eeca 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -40,8 +40,8 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: return n_gate + expected_expert def transform_func(gm) -> None: - sharding_config = ShardingConfig() - detect_ep_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig(rank, world_size) + detect_ep_shard(gm, sharding_config) sharding_transform_executor(gm, sharding_config) op_expected = torch.ops.auto_deploy.torch_dist_all_reduce @@ -89,8 +89,8 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> ) # get detected transformations - sharding_config = ShardingConfig() - detect_ep_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig(rank, world_size) + detect_ep_shard(gm, sharding_config) detected_transformations = sharding_config.ep_transforms # Run pattern detection test diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 9e33bef4a91..5c76763845b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -14,13 +14,45 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transformations.library import ( ShardingConfig, + ShardingConfigSource, SplitDimension, TPShardingInfo, detect_column_row_shard, + detect_sharding_from_factory_config, sharding_transform_executor, ) from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op +base_model_tp_plan = { + "q_proj": "colwise", + "k_proj": "colwise", + "v_proj": "colwise", + "o_proj": "rowwise", + "gate_proj": "colwise", + "up_proj": "colwise", + "down_proj": "rowwise", + "linear1": "colwise", + "linear2": "rowwise", + "linear": "gather", + # "input_layernorm.weight": "sequence_parallel", + # "post_attention_layernorm.weight": "sequence_parallel", + # "norm.weight": "sequence_parallel", + # "shared_expert.gate_proj": "local_colwise", + # "shared_expert.up_proj": "local_colwise", + # "shared_expert.down_proj": "local_rowwise", + # "experts.gate_up_proj": "local_packed_rowwise", + # "experts.down_proj": "local_colwise", + # "experts": "local", + "feed_forward": "gather", + "self": "gather", + "weight": "gather", +} + +predefined_config = { + "head_dim": 8, + "tp_plan": base_model_tp_plan, +} + class GQA_Block(nn.Module): def __init__( @@ -83,6 +115,7 @@ def _run_job( model_cls: nn.Module, dist_op_expected: str, bias: bool, + from_config: bool, rank: int, world_size: int, ) -> None: @@ -129,6 +162,7 @@ def _get_expected_num_params(num_p_og: int) -> int: num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size else: num_params = num_p_og // world_size + num_update + print(f"\n\nnum_p_og: {num_p_og}, num_params: {num_params}") return num_params def verify_local_weight_sizes(gm) -> bool: @@ -147,8 +181,19 @@ def verify_local_weight_sizes(gm) -> bool: op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) def transform_func(gm) -> None: - sharding_config = 
ShardingConfig() - detect_column_row_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig( + rank=rank, + world_size=world_size, + factory_source=ShardingConfigSource.HUGGINGFACE, + sharding_config=predefined_config, + simple_shard_only=False, + use_sharding_from_factory=from_config, + ) + if from_config: + if world_size > 1: + detect_sharding_from_factory_config(gm, sharding_config) + else: + detect_column_row_shard(gm, sharding_config) sharding_transform_executor(gm, sharding_config) def combined_graph_check(gm) -> bool: @@ -174,6 +219,7 @@ def _run_pattern_detection_job( bias: bool, rank: int, world_size: int, + from_config: bool, ) -> None: # init model and input batch_size = 4 @@ -200,7 +246,7 @@ def _run_pattern_detection_job( gm = torch_export_to_gm(model, args=(x,), clone=True) expected_transformations = [] # if world_size == 1, no sharding transformations should be detected - if world_size > 1: + if world_size > 1 or from_config: if model_cls == GQA_Block: min_local_shape = num_features // num_heads for node in gm.graph.nodes: @@ -210,10 +256,10 @@ def _run_pattern_detection_job( # for O layer, we expect: # dim = 1, add_dist = True if "o_proj" in node.args[1].name: - dim = SplitDimension.COLUMN + dim = SplitDimension.ROW dist_op = "all_reduce" else: - dim = SplitDimension.ROW + dim = SplitDimension.COLUMN dist_op = None expected_transformations.append( TPShardingInfo( @@ -231,10 +277,10 @@ def _run_pattern_detection_job( # linear1 should be sharded on dim=0, add_dist=False, min_local_shape=1 # linear2 should be sharded on dim=1, add_dist=True, min_local_shape=1 if "linear1" in node.args[1].name: - dim = SplitDimension.ROW + dim = SplitDimension.COLUMN dist_op = None else: - dim = SplitDimension.COLUMN + dim = SplitDimension.ROW dist_op = "all_reduce" expected_transformations.append( TPShardingInfo( @@ -253,7 +299,7 @@ def _run_pattern_detection_job( expected_transformations.append( TPShardingInfo( target_node=node.name, - split_dim=SplitDimension.ROW, # Simple shard uses dim=0 + split_dim=SplitDimension.COLUMN, # Simple shard uses dim=0 rank=rank, world_size=world_size, dist_op="all_gather", @@ -262,8 +308,18 @@ def _run_pattern_detection_job( ) # get detected transformations - sharding_config = ShardingConfig() - detect_column_row_shard(gm, rank, world_size, sharding_config) + sharding_config = ShardingConfig( + rank=rank, + world_size=world_size, + factory_source=ShardingConfigSource.HUGGINGFACE, + sharding_config=predefined_config, + simple_shard_only=False, + use_sharding_from_factory=from_config, + ) + if from_config: + detect_sharding_from_factory_config(gm, sharding_config) + else: + detect_column_row_shard(gm, sharding_config) detected_transformations = sharding_config.tp_transforms # Run pattern detection test @@ -272,6 +328,7 @@ def _run_pattern_detection_job( @pytest.mark.parametrize("device_count", get_device_counts()) @pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("from_config", [False, True]) @pytest.mark.parametrize( "model_cls, dist_op_expected", ( @@ -280,15 +337,22 @@ def _run_pattern_detection_job( (GQA_Block, "torch_dist_all_reduce"), ), ) -def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int): +def test_sharding( + model_cls: Type[nn.Module], + dist_op_expected: str, + bias: bool, + device_count: int, + from_config: bool, +): dist_common.spawn_multiprocess_job( - job=partial(_run_job, model_cls, dist_op_expected, bias), + job=partial(_run_job, model_cls, 
dist_op_expected, bias, from_config), size=device_count, ) @pytest.mark.parametrize("world_size", [1, 8]) @pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("from_config", [False, True]) @pytest.mark.parametrize( "model_cls, dist_op_expected", ( @@ -298,11 +362,15 @@ def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, ), ) def test_sharding_pattern_detection( - model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int + model_cls: Type[nn.Module], + dist_op_expected: str, + bias: bool, + world_size: int, + from_config: bool, ): """Test pattern detection logic without distributed execution. This test verifies only the pattern detection logic with provided world_size. No need to run distributed job, can be run on single process. """ - _run_pattern_detection_job(model_cls, bias, 0, world_size) + _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config)
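# ---------------------------------------------------------------------------
# Reviewer note (illustrative, not part of the patch): a minimal, standalone
# sketch of the wildcard-to-regex matching that detect_sharding_from_factory_config
# performs on HF-style tp_plan keys. The helper name `tp_plan_action` and the
# plan entries are hypothetical; the escape-dots / replace-"*" trick assumes the
# Python >= 3.7 behavior of re.escape (only regex metacharacters are escaped).
import re
from typing import Optional


def tp_plan_action(module_name: str, tp_plan: dict) -> Optional[str]:
    """Return the action of the first tp_plan key whose wildcard pattern matches module_name."""
    for key, action in tp_plan.items():
        pattern_string = "*" + key + "*"
        # substitute "*" with a placeholder, escape regex metacharacters (the dots),
        # then turn the placeholder back into ".*"
        pattern_regex = re.escape(pattern_string.replace("*", "@")).replace("@", ".*")
        if re.match(pattern_regex, module_name):
            return action
    return None


tp_plan = {"layers.*.self_attn.q_proj": "colwise", "layers.*.mlp.down_proj": "rowwise"}
assert tp_plan_action("model.layers.3.self_attn.q_proj.weight", tp_plan) == "colwise"
assert tp_plan_action("model.layers.0.mlp.down_proj.weight", tp_plan) == "rowwise"
assert tp_plan_action("model.embed_tokens.weight", tp_plan) is None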
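# ---------------------------------------------------------------------------
# Reviewer note (illustrative, not part of the patch): a standalone sketch of why
# the renamed SplitDimension values map to dim 0 / dim 1 of the stored nn.Linear
# weight. For y = x @ W^T with W of shape [out_features, in_features], HF "colwise"
# shards the output features (dim 0 of W, concatenated back via all_gather), while
# "rowwise" shards the input features (dim 1 of W) and needs an all_reduce over
# partial sums.
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)  # [batch, in_features]
W = torch.randn(6, 8)  # [out_features, in_features]
world_size = 2

# "colwise" (SplitDimension.COLUMN == 0): each rank owns a slice of the output features.
y_col = torch.cat([x @ w.T for w in torch.chunk(W, world_size, dim=0)], dim=-1)
assert torch.allclose(y_col, x @ W.T, atol=1e-5)  # all_gather over the feature dim

# "rowwise" (SplitDimension.ROW == 1): each rank owns a slice of the input features.
partials = [
    xr @ wr.T
    for xr, wr in zip(torch.chunk(x, world_size, dim=-1), torch.chunk(W, world_size, dim=1))
]
assert torch.allclose(sum(partials), x @ W.T, atol=1e-5)  # all_reduce (sum) of partial outputs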