Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
Expand Down Expand Up @@ -53,20 +53,28 @@
def _aten_lowering_pass(
*args: LoweringPassSignature,
index: Optional[int] = None,
**kwargs: Any,
) -> Union[
LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
"""Adds a lowering pass to the registry, at a specified index if desired

If no index is specified, the lowering pass is inserted at the end of the list

Additional keyword arguments can be passed to configure the lowering pass behavior.
These will be stored as metadata on the pass function.
"""

def add_lowering_pass(
lowering_pass: LoweringPassSignature,
) -> LoweringPassSignature:
# Store additional parameters as metadata on the function
if kwargs:
lowering_pass._lowering_pass_config = kwargs

ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
logger.debug(
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}"
)
return lowering_pass

Expand All @@ -81,7 +89,7 @@ def add_lowering_pass(
f"aten_lowering_pass decorator called with invalid arguments {args} "
"To specify an index to insert the pass, use the keyword 'index='"
)
# If no arguments are specified, the decorator was called with an index keyword
# If no arguments are specified, the decorator was called with keyword arguments
else:
return add_lowering_pass

Expand All @@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None:
return


def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]:
"""Get the configuration parameters for a lowering pass function

Args:
lowering_pass: The lowering pass function

Returns:
Dictionary containing the configuration parameters, or empty dict if none
"""
return getattr(lowering_pass, "_lowering_pass_config", {})


def post_lowering(
gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
) -> torch.fx.GraphModule:
Expand Down
1 change: 1 addition & 0 deletions tools/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ We have officially verified support for the following models:
| LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct<br>meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes |
| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-4B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
| Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
| Gemma 3 | google/gemma-3-1b-it | FP16, FP32 | Yes |


### Usage
Expand Down
5 changes: 5 additions & 0 deletions tools/llm/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def get_model(args):
.eval()
.cuda()
)
# register SDPA variant for the model
if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
else:
register_sdpa._SDPA_MAPPING["default"](model_config=model.config)

if args.precision == "FP16":
model = model.to(torch.float16)
Expand Down
Loading
Loading