1 change: 1 addition & 0 deletions .github/workflows/release-windows.yml
@@ -1,6 +1,7 @@
name: Release Windows wheels artifacts

on:
pull_request:
push:
tags:
# NOTE: Binary build pipelines should only get triggered on release candidate builds
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,8 @@ repos:
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 25.1.0
# pin to a lower version for py3.9 compatibility
rev: 23.12.1
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -415,7 +415,11 @@ def index_dtype_validator(
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype not in (torch.int32, torch.int64):
if val is not None and val.dtype not in (
torch.int32,
torch.int64,
torch.bool,
):
return False
return True

@@ -424,6 +428,7 @@ def index_dtype_validator(
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
@enforce_tensor_types(
{
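For reference, a minimal eager-mode sketch of the indexing pattern the relaxed validator now admits (a torch.bool mask passed straight to aten.index.Tensor); the tensor shapes here are illustrative only:

import torch

x = torch.randn(2, 2)
mask = torch.tensor([True, False])

# aten.index.Tensor takes a list of optional index tensors; a boolean mask
# selects along the leading dimension exactly like x[mask] in eager mode
y = torch.ops.aten.index.Tensor(x, [mask])
assert torch.equal(y, x[mask])
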
69 changes: 66 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -14,7 +14,6 @@
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
@@ -51,6 +50,71 @@ def select(
return layer.get_output(0)


def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)):
return bool(tensor.dtype == torch.bool)
# when the index is an FX node, check the dtype recorded in its "val" metadata
val = tensor.meta.get("val")
return val is not None and val.dtype is torch.bool


def expand_boolean_indices(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
new_indices = []
for i, ind in enumerate(indices):
if ind is not None and is_boolean_tensor(ind):
_LOGGER.debug(
f"Boolean index detected at position {i}, converting with nonzero()"
)
mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

nonzero_layer = ctx.net.add_non_zero(mask_tensor)
set_layer_name(
nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
)
nonzero_indices = nonzero_layer.get_output(0)

# nonzero returns shape [N, dims]; extract the column for dimension i
if len(indices) == 1:
# x[mask] with a single 1D mask: use the nonzero output directly
to_squeeze = nonzero_indices
else:
# Advanced multi-axis mask: extract index i from shape [N, D]
gather_axis = 1 # dim index
gather_layer = ctx.net.add_gather(
nonzero_indices,
get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
gather_axis,
)
set_layer_name(
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
)
to_squeeze = gather_layer.get_output(0)
squeeze_layer = ctx.net.add_shuffle(to_squeeze)
squeeze_layer.reshape_dims = (-1,)
set_layer_name(
squeeze_layer,
target,
name + f"_bool_mask_squeeze_{i}",
source_ir,
)
squeezed_index = squeeze_layer.get_output(0)
new_indices.append(squeezed_index)
else:
new_indices.append(ind)
return new_indices


def index(
ctx: ConversionContext,
target: Target,
@@ -61,13 +125,12 @@ def index(
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)
# is_numpy indicates whether all the indices are numpy arrays or torch.Tensors;
# if any index is not, the flag is set to False
_LOGGER.debug(
"Determining whether aten.index constant-index optimization can be invoked"
)
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray))
for ind in indices
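As a sanity check on expand_boolean_indices, here is a small eager-mode sketch (shapes and values are illustrative) of the equivalence the converter relies on: a boolean mask on an axis behaves like integer indexing with the flattened nonzero() positions of that mask, which is what the INonZeroLayer + gather + shuffle sequence reproduces inside TensorRT:

import torch

x = torch.randn(2, 4, 4, 2)
mask1 = torch.tensor([True, False, True, False])  # boolean mask on axis 1
mask3 = torch.tensor([True, True])                # boolean mask on axis 3

# single-axis case: a mask is equivalent to the flattened nonzero() positions
idx1 = mask1.nonzero().reshape(-1)
assert torch.equal(x[:, mask1], x[:, idx1])

# multi-axis case: each mask becomes its own integer index tensor, which then
# participates in ordinary advanced indexing (one nonzero/gather/reshape per
# masked axis in the converter)
idx3 = mask3.nonzero().reshape(-1)
assert torch.equal(x[:, mask1, :, mask3], x[:, idx1, :, idx3])
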
26 changes: 23 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -53,20 +53,28 @@
def _aten_lowering_pass(
*args: LoweringPassSignature,
index: Optional[int] = None,
**kwargs: Any,
) -> Union[
LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
"""Adds a lowering pass to the registry, at a specified index if desired

If no index is specified, the lowering pass is inserted at the end of the list

Additional keyword arguments can be passed to configure the lowering pass behavior.
These will be stored as metadata on the pass function.
"""

def add_lowering_pass(
lowering_pass: LoweringPassSignature,
) -> LoweringPassSignature:
# Store additional parameters as metadata on the function
if kwargs:
lowering_pass._lowering_pass_config = kwargs

ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
logger.debug(
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}"
)
return lowering_pass

@@ -81,7 +89,7 @@ def add_lowering_pass(
f"aten_lowering_pass decorator called with invalid arguments {args} "
"To specify an index to insert the pass, use the keyword 'index='"
)
# If no arguments are specified, the decorator was called with an index keyword
# If no arguments are specified, the decorator was called with keyword arguments
else:
return add_lowering_pass

@@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None:
return


def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]:
"""Get the configuration parameters for a lowering pass function

Args:
lowering_pass: The lowering pass function

Returns:
Dictionary containing the configuration parameters, or empty dict if none
"""
return getattr(lowering_pass, "_lowering_pass_config", {})


def post_lowering(
gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
) -> torch.fx.GraphModule:
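A minimal usage sketch of the new decorator keyword arguments together with get_lowering_pass_config; the pass name and the "threshold" option are hypothetical, and the usual (gm, settings) lowering-pass signature is assumed:

import torch

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
    get_lowering_pass_config,
)


# extra keyword arguments are stored on the function as _lowering_pass_config
@_aten_lowering_pass(index=0, threshold=0.5)
def my_hypothetical_pass(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    config = get_lowering_pass_config(my_hypothetical_pass)  # {"threshold": 0.5}
    # ... rewrite gm using config["threshold"] ...
    return gm
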
47 changes: 46 additions & 1 deletion tests/py/dynamo/conversion/test_index_aten.py
@@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase):
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
torch.randn(2, 4, 4, 2),
),
(
"mask_index_three_dim",
[None, torch.tensor([True, False]), None],
torch.randn(2, 2, 2),
),
(
"mask_index_two_dim",
[torch.tensor([True, False])],
torch.randn(2, 2),
),
(
# covers boolean masks on multiple, non-adjacent axes
"mask_index_multi_axis",
[
None,
torch.tensor([True, False]), # axis 1
None,
torch.tensor([True, False]), # axis 3
],
torch.randn(2, 4, 4, 2),
),
]
)
def test_index_constant(self, _, index, input):
@@ -168,7 +189,31 @@ def forward(self, input):
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(
TestModule(), input_specs, use_dynamo_tracer=True
)


class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
def test_index_input_non_dynamic_index_dynamic(self):
class TestIndexWithRuntimeIndex(torch.nn.Module):
def forward(self, x):
mask = x > 0
idx = torch.nonzero(mask, as_tuple=True)
return torch.ops.aten.index.Tensor(x, idx)

input_specs = [
Input(
min_shape=(2, 2),
opt_shape=(2, 2),
max_shape=(8, 8),
dtype=torch.float32,
),
]
# In this case the index argument (args[1]) is itself converted to a list of TRTTensors when use_dynamo_tracer=True
self.run_test_with_dynamic_shape(
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
)


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions tools/llm/run_llm.py
@@ -58,6 +58,10 @@ def get_model(args):
.eval()
.cuda()
)
if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
else:
register_sdpa._SDPA_MAPPING["default"](model_config=model.config)

if args.precision == "FP16":
model = model.to(torch.float16)
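The same dispatch can be written a bit more compactly with dict.get, assuming _SDPA_MAPPING always carries a "default" entry (which the fallback branch above already implies):

sdpa_registrar = register_sdpa._SDPA_MAPPING.get(
    args.model, register_sdpa._SDPA_MAPPING["default"]
)
sdpa_registrar(model_config=model.config)
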
2 changes: 1 addition & 1 deletion tools/llm/static_cache_v2.py
@@ -254,7 +254,7 @@ def insert_kv_slicing_before_sdpa(
sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
attn_mask,
dropout_p,
True,
is_causal,
)

# kv_cache_for_graph.extend([k_node, v_node])