26 changes: 23 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -53,20 +53,28 @@
def _aten_lowering_pass(
*args: LoweringPassSignature,
index: Optional[int] = None,
**kwargs: Any,
) -> Union[
LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
"""Adds a lowering pass to the registry, at a specified index if desired

If no index is specified, the lowering pass is inserted at the end of the list

Additional keyword arguments can be passed to configure the lowering pass behavior.
These will be stored as metadata on the pass function.
"""

def add_lowering_pass(
lowering_pass: LoweringPassSignature,
) -> LoweringPassSignature:
# Store additional parameters as metadata on the function
if kwargs:
lowering_pass._lowering_pass_config = kwargs

ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
logger.debug(
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}"
)
return lowering_pass

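For reference, a minimal sketch of the intended usage of the extended decorator. The pass name and the `verbose` keyword below are hypothetical; only `index` and the kwargs-as-metadata behavior come from this diff:

```python
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)


# `verbose` is not interpreted by the decorator itself; it is stored on the
# function as `_lowering_pass_config` for the pass (or other code) to read back.
@_aten_lowering_pass(index=0, verbose=True)
def my_noop_pass(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    return gm
```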
@@ -81,7 +89,7 @@ def add_lowering_pass(
f"aten_lowering_pass decorator called with invalid arguments {args} "
"To specify an index to insert the pass, use the keyword 'index='"
)
# If no arguments are specified, the decorator was called with an index keyword
# If no arguments are specified, the decorator was called with keyword arguments
else:
return add_lowering_pass

@@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None:
return


def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]:
"""Get the configuration parameters for a lowering pass function

Args:
lowering_pass: The lowering pass function

Returns:
Dictionary containing the configuration parameters, or empty dict if none
"""
return getattr(lowering_pass, "_lowering_pass_config", {})


def post_lowering(
gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
) -> torch.fx.GraphModule:
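And the read side, continuing the hypothetical `my_noop_pass` registration sketched above: the stored keywords come back via `get_lowering_pass_config`, with an empty dict for passes registered without extra keywords.

```python
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    get_lowering_pass_config,
)

# For the registration above this yields {"verbose": True}; a pass registered
# without extra keywords yields {}.
config = get_lowering_pass_config(my_noop_pass)
assert config == {"verbose": True}
```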
4 changes: 4 additions & 0 deletions tools/llm/run_llm.py
@@ -58,6 +58,10 @@ def get_model(args):
.eval()
.cuda()
)
if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
else:
register_sdpa._SDPA_MAPPING["default"](model_config=model.config)

if args.precision == "FP16":
model = model.to(torch.float16)
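Side note on the dispatch above: the lookup-with-fallback could equivalently be written with `dict.get`, avoiding the double lookup. A sketch only, not part of the diff:

```python
# Fall back to the "default" registration when the model has no dedicated entry.
register_fn = register_sdpa._SDPA_MAPPING.get(
    args.model, register_sdpa._SDPA_MAPPING["default"]
)
register_fn(model_config=model.config)
```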
248 changes: 165 additions & 83 deletions tools/llm/torchtrt_ext/register_sdpa.py
@@ -1,7 +1,7 @@
import copy
import logging
import operator
from typing import Callable, Sequence, Tuple
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -13,6 +13,7 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
from transformers import AutoConfig, Gemma3TextConfig

from .sdpa_converter import *

@@ -33,94 +34,175 @@
torch.ops.aten._scaled_dot_product_flash_attention.default,
}

from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
get_lowering_pass_config,
)

@_aten_lowering_pass
def replace_variants_of_sdpa(
gm: torch.fx.GraphModule, settings: CompilationSettings

def _process_sdpa_node(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
settings: CompilationSettings,
sliding_window_size: Optional[int] = None,
use_gqa: bool = False,
) -> torch.fx.GraphModule:
"""Replace scaled_dot_product_attention with an equivalent
implementation which can be accurately converted to TRT
"""
"""Helper function to process SDPA nodes with common logic."""

if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
if len(node.args) == 7:
(
query,
key,
value,
attn_mask,
compute_log_sumexp,
dropout_p,
is_causal,
) = node.args
elif len(node.args) == 5:
query, key, value, attn_mask, is_causal = node.args
dropout_p = 0.0
else:
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
)
elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
if len(node.args) == 6:
(
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
) = node.args
elif len(node.args) == 5:
query, key, value, dropout_p, is_causal = node.args
elif len(node.args) == 3:
query, key, value = node.args
dropout_p = 0.0
is_causal = True
else:
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
)
else:
return gm

# Always set causal to True and generate attn_mask inside the sdpa operator
attn_mask = None
is_causal = True
for node in gm.graph.nodes:
if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
if (
node.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
):
if len(node.args) == 7:
(
query,
key,
value,
attn_bias,
compute_log_sumexp,
dropout_p,
is_causal,
) = node.args
elif len(node.args) == 5:
query, key, value, attn_mask, is_causal = node.args
dropout_p = 0.0

else:
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
)
elif (
node.target
== torch.ops.aten._scaled_dot_product_flash_attention.default
):
if len(node.args) == 6:
query, key, value, dropout_p, is_causal, return_debug_mask = (
node.args
)
if len(node.args) == 5:
query, key, value, dropout_p, is_causal = node.args
elif len(node.args) == 3:
query, key, value = node.args
dropout_p = 0.0
is_causal = True
else:
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
)
dropout_p = 0.0

logger.warning(
f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, "
f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}"
)

modified_input_args = (
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
)

# Create a new node with torch.nn.functional.scaled_dot_product_attention
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
torch.nn.functional.scaled_dot_product_attention,
args=modified_input_args,
kwargs={
"scale": node.kwargs.get("scale", None),
"use_fp32_acc": settings.use_fp32_acc,
"sliding_window_size": sliding_window_size,
},
)

# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
new_node.meta = copy.copy(node.meta)
# Check if there's a getitem node following this attention node
for user in list(node.users):
if user.op == "call_function" and user.target == operator.getitem:
# If the getitem is extracting the first element (the output tensor)
if user.args[1] == 0:
# Replace all uses of the getitem with the new attention node
user.replace_all_uses_with(new_node)
new_node.meta["val"] = new_node.meta["val"][0]
# Replace all uses of the original node with the new node
node.replace_all_uses_with(new_node)

gm.graph.erase_node(node)
return gm


def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
@_aten_lowering_pass(index=index, model_config=model_config)
def gemma3_sdpa_pass(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""SDPA pass specifically for Gemma3 models with sliding window attention."""
config = get_lowering_pass_config(gemma3_sdpa_pass)
sliding_window = None
layer_types = None
model_config = config.get("model_config", None)
if not isinstance(model_config, Gemma3TextConfig):
logger.warning(
f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead"
)
modified_input_args = (query, key, value, None, dropout_p, True)
# Create a new node with torch.nn.functional.scaled_dot_product_attention
# The input args is (query, key, value, is_causal). kwargs has scale
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
torch.nn.functional.scaled_dot_product_attention,
args=modified_input_args,
kwargs={
"scale": node.kwargs.get("scale", None),
"use_fp32_acc": settings.use_fp32_acc,
},
else:
sliding_window = getattr(model_config, "sliding_window", None)
layer_types = getattr(model_config, "layer_types", None)
logger.debug(
f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}"
)

index = 0
for node in gm.graph.nodes:
if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
sliding_window_size = None
if (
sliding_window is not None
and sliding_window > 0
and layer_types is not None
and index < len(layer_types)
):
if layer_types[index] == "sliding_attention":
sliding_window_size = sliding_window
index += 1

# Process the node
logger.debug(
f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}"
)
gm = _process_sdpa_node(gm, node, settings, sliding_window_size)

# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
new_node.meta = copy.copy(node.meta)
# Check if there's a getitem node following this attention node
for user in list(node.users):
if user.op == "call_function" and user.target == operator.getitem:
# If the getitem is extracting the first element (the output tensor)
if user.args[1] == 0:
# Replace all uses of the getitem with the new attention node
user.replace_all_uses_with(new_node)
new_node.meta["val"] = new_node.meta["val"][0]
# Replace all uses of the original node with the new node
node.replace_all_uses_with(new_node)

gm.graph.erase_node(node)

# Clean up the graph
clean_up_graph_after_modifications(gm)

logger.debug(
"Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
)
return gm
clean_up_graph_after_modifications(gm)
logger.debug("Applied Gemma3-specific SDPA replacement")
return gm


def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
@_aten_lowering_pass(index=index, model_config=model_config)
def default_sdpa_pass(
gm: torch.fx.GraphModule,
settings: CompilationSettings,
) -> torch.fx.GraphModule:
"""Default SDPA pass for models without specific implementations."""

for node in gm.graph.nodes:
if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
# Process the node with default logic
gm = _process_sdpa_node(gm, node, settings)

clean_up_graph_after_modifications(gm)
logger.debug("Applied default SDPA replacement")
return gm


# Global registry for SDPA passes
_SDPA_MAPPING: Dict[str, Callable] = {
"google/gemma-3-1b-it": register_gemma3_sdpa_pass,
"default": register_default_sdpa_pass,
}
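To illustrate how the registry is meant to grow, here is a hedged sketch of adding a model-specific entry that follows the same closure pattern as `register_default_sdpa_pass`. The model name and `register_my_model_sdpa_pass` are hypothetical, and the sketch assumes the helpers already defined in `register_sdpa.py` (`_process_sdpa_node`, `REPLACEABLE_ATEN_OPS`, `clean_up_graph_after_modifications`):

```python
def register_my_model_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
    @_aten_lowering_pass(index=index, model_config=model_config)
    def my_model_sdpa_pass(
        gm: torch.fx.GraphModule, settings: CompilationSettings
    ) -> torch.fx.GraphModule:
        # model_config is available via get_lowering_pass_config(my_model_sdpa_pass)
        # if the model needs config-dependent handling (see the Gemma3 pass above).
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
                gm = _process_sdpa_node(gm, node, settings)

        clean_up_graph_after_modifications(gm)
        return gm


_SDPA_MAPPING["my-org/my-model"] = register_my_model_sdpa_pass
```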