From 060266aab50d0097bcbf9c241a81d0ded50109d9 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 22 Aug 2025 12:45:19 -0700 Subject: [PATCH 01/12] initial check in --- .pre-commit-config.yaml | 3 +- tools/llm/torchtrt_ext/register_sdpa.py | 17 ++++-- tools/llm/torchtrt_ext/sdpa_converter.py | 70 +++++++++++++----------- 3 files changed, 51 insertions(+), 39 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7b91eec34..c12918fdc9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,7 +47,8 @@ repos: hooks: - id: ruff - repo: https://github.com/psf/black - rev: 25.1.0 + # pin to a lower version for py3.9 compatibility + rev: 23.12.1 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..0a3efa858f 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -72,9 +72,14 @@ def replace_variants_of_sdpa( == torch.ops.aten._scaled_dot_product_flash_attention.default ): if len(node.args) == 6: - query, key, value, dropout_p, is_causal, return_debug_mask = ( - node.args - ) + ( + query, + key, + value, + dropout_p, + is_causal, + return_debug_mask, + ) = node.args if len(node.args) == 5: query, key, value, dropout_p, is_causal = node.args elif len(node.args) == 3: @@ -87,11 +92,11 @@ def replace_variants_of_sdpa( ) logger.warning( - f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." + f"This current version of SDPA converter only supports attn_mask = {attn_mask}, dropout_p = {dropout_p} and is_causal = {is_causal} configuration. This could cause issues with accuracy for models with different configurations." ) - modified_input_args = (query, key, value, None, dropout_p, True) + modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, is_causal). kwargs has scale + # The input args is (query, key, value, attn_mask, dropout_p, is_causal). 
kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 47083c7b48..a84008db47 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -66,7 +66,11 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + + assert ( + not is_causal and attn_mask is None + ), "either is_causal or attn_mask should be set" + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype @@ -134,37 +138,39 @@ def scaled_dot_product_attention( L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. - # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) - - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + if is_causal: + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. 
+ # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + else: + attn_bias = attn_mask scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias From 1d4d4397ceb0730e87a58774032614b2a430edad Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 22 Aug 2025 17:08:38 -0700 Subject: [PATCH 02/12] add kv cache support(not working yet) --- tools/llm/static_cache_v2.py | 20 +++++----- tools/llm/torchtrt_ext/register_sdpa.py | 7 ++-- tools/llm/torchtrt_ext/sdpa_converter.py | 51 +++++++++++++++++++----- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py index 4634b79a52..1bd2e420c9 100644 --- a/tools/llm/static_cache_v2.py +++ b/tools/llm/static_cache_v2.py @@ -233,16 +233,18 @@ def insert_kv_slicing_before_sdpa( q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args incoming_key, incoming_value = incoming_keys_values[idx] # For keys - new_current_key_node, new_incoming_key_cache_node = ( - create_kv_cache_update_nodes( - gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input - ) + ( + new_current_key_node, + new_incoming_key_cache_node, + ) = create_kv_cache_update_nodes( + gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input ) # For values - new_current_value_node, new_incoming_value_cache_node = ( - create_kv_cache_update_nodes( - gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input - ) + ( + new_current_value_node, + new_incoming_value_cache_node, + ) = create_kv_cache_update_nodes( + gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input ) # Store the KV cache nodes for the current SDPA node @@ -254,7 +256,7 @@ def insert_kv_slicing_before_sdpa( sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + ( attn_mask, dropout_p, - True, + is_causal, ) # kv_cache_for_graph.extend([k_node, v_node]) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 0a3efa858f..da5363c6f8 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -41,9 +41,10 @@ def replace_variants_of_sdpa( """Replace scaled_dot_product_attention with an equivalent implementation which can be accurately converted to TRT """ - attn_mask = None - is_causal = True + for node in gm.graph.nodes: + attn_mask = None + is_causal = False if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: if ( node.target @@ -54,7 +55,7 @@ def replace_variants_of_sdpa( query, key, value, - attn_bias, + attn_mask, compute_log_sumexp, dropout_p, is_causal, diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index a84008db47..e78dad79dd 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,24 +27,50 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = 
None, ) -> TRTTensor: row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - col_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor + ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask @@ -68,8 +94,11 @@ def scaled_dot_product_attention( source_ir = SourceIR.ATEN assert ( - not is_causal and attn_mask is None - ), "either is_causal or attn_mask should be set" + is_causal or attn_mask is not None + ), "at least one of is_causal or attn_mask should be set" + assert is_causal ^ ( + attn_mask is not None + ), "Exactly one of is_causal or attn_mask must be set" # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) From 22197e2e8b8ff0be4c9fe613ce8306fc3f02f4d5 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 25 Aug 2025 13:21:04 -0700 Subject: [PATCH 03/12] add if_conditional, not working though. 
--- tools/llm/static_cache_v2.py | 18 ++--- tools/llm/torchtrt_ext/register_sdpa.py | 35 +++++++++ tools/llm/torchtrt_ext/sdpa_converter.py | 97 ++++++++++++++++++++++-- 3 files changed, 134 insertions(+), 16 deletions(-) diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py index 1bd2e420c9..6d6c4409e9 100644 --- a/tools/llm/static_cache_v2.py +++ b/tools/llm/static_cache_v2.py @@ -233,18 +233,16 @@ def insert_kv_slicing_before_sdpa( q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args incoming_key, incoming_value = incoming_keys_values[idx] # For keys - ( - new_current_key_node, - new_incoming_key_cache_node, - ) = create_kv_cache_update_nodes( - gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input + new_current_key_node, new_incoming_key_cache_node = ( + create_kv_cache_update_nodes( + gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input + ) ) # For values - ( - new_current_value_node, - new_incoming_value_cache_node, - ) = create_kv_cache_update_nodes( - gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input + new_current_value_node, new_incoming_value_cache_node = ( + create_kv_cache_update_nodes( + gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input + ) ) # Store the KV cache nodes for the current SDPA node diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index da5363c6f8..9aafb2ad11 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -129,4 +129,39 @@ def replace_variants_of_sdpa( logger.debug( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) + add_attn_mask_as_output = False + if add_attn_mask_as_output: + add_one_attn_mask_as_output(gm) return gm + + +# try to add one of the attn_mask as output, so that I can actually see the shape and value in the generation phase. 
+def add_one_attn_mask_as_output(gm: torch.fx.GraphModule): + import torch.utils._pytree as pytree + from cache_utils import create_random_output_tensors + + attn_mask_node = None + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.nn.functional.scaled_dot_product_attention + ): + attn_mask_node = node.args[3] + break + + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + current_outputs = output_node.args[0] + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + (attn_mask_node,) + else: + new_outputs = (current_outputs, attn_mask_node) + output_node.args = new_outputs + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + gm = clean_up_graph_after_modifications(gm) + new_output_tensors = create_random_output_tensors(new_outputs) + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + return gm \ No newline at end of file diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index e78dad79dd..3def60a8ca 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -95,10 +95,10 @@ def scaled_dot_product_attention( assert ( is_causal or attn_mask is not None - ), "at least one of is_causal or attn_mask should be set" + ), "One of is_causal or attn_mask should be set" assert is_causal ^ ( attn_mask is not None - ), "Exactly one of is_causal or attn_mask must be set" + ), "Exactly one of is_causal or attn_mask should be set" # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) @@ -198,12 +198,97 @@ def scaled_dot_product_attention( attn_bias = impl.unary.log( ctx, target, source_ir, name + "_log", one_minus_temp_mask ) + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) else: - attn_bias = attn_mask + use_if_conditional = False + if not use_if_conditional: + # works in non cache scenario, but in kv cache, got the following error: + # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [ELEMENTWISE]-[aten_ops.scaled_dot_product_attention]-[model.layers.0.self_attn/scaled_dot_product_attention_attn_mask_add]: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 5 != 71 && 5 != 1 && 71 != 1.) + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_mask_add", mm, attn_mask + ) + else: + if_option = "if_conditional_subgraph" # if_conditional_subgraph or if_conditional or if_conditional_input + if if_option == "if_conditional_subgraph": + # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L46 + # if_conditional_subgraph is not working, got the following error: + # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block + # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. 
Expect the graph has single block) - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias - ) + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + # if I do not squeeze, it will throw the error: condition must be a scalar tensor + condition = impl.squeeze.squeeze( + ctx, target, source_ir, name + "_unsqueeze", need_mask, 0 + ) + if_layer = ctx.net.add_if_conditional() + if_layer.set_condition(condition) + cond_input1 = if_layer.add_input(mm) + cond_input2 = if_layer.add_input(attn_mask) + + true_input = impl.elementwise.add( + ctx, + target, + source_ir, + name + "_attn_bias_add", + cond_input1.get_output(0), + cond_input2.get_output(0), + ) + false_input = cond_input1.get_output(0) + output_layer = if_layer.add_output(true_input, false_input) + scaled_add_attn_bias = output_layer.get_output(0) + elif if_option == "if_conditional_input": + # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L17 + # if_conditional_input is not working, got the following error: + # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block + # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block) + + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + # if I do not squeeze, it will throw the error: condition must be a scalar tensor + condition = impl.squeeze.squeeze( + ctx, target, source_ir, name + "_unsqueeze", need_mask, 0 + ) + if_layer = ctx.net.add_if_conditional() + if_layer.set_condition(condition) + true_input = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask + ) + false_input = mm + true_cond_input = if_layer.add_input(true_input) + false_cond_input = if_layer.add_input(false_input) + output_layer = if_layer.add_output( + true_cond_input.get_output(0), false_cond_input.get_output(0) + ) + scaled_add_attn_bias = output_layer.get_output(0) + elif if_option == "if_conditional": + # reference: https://github.com/pytorch/TensorRT/blob/535c6a8341a3258a9c311406a9af50eb3c68c5a6/examples/dynamo/llm/cache_utils.py#L15-L44 + # if_conditional is not working, got the following error: + # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block + # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. 
Expect the graph has single block) + + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + # if I do not squeeze, it will throw the error: condition must be a scalar tensor + condition = impl.squeeze.squeeze( + ctx, target, source_ir, name + "_unsqueeze", need_mask, 0 + ) + if_layer = ctx.net.add_if_conditional() + if_layer.set_condition(condition) + true_input = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask + ) + false_input = mm + output_layer = if_layer.add_output( + true_input.get_output(0), false_input.get_output(0) + ) + scaled_add_attn_bias = output_layer.get_output(0) + softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False From 2b93ac11d3cda830ff02c2556023dd8a8f54134c Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 17 Jul 2025 11:13:45 -0700 Subject: [PATCH 04/12] Index converter dynamic cases fix --- .../dynamo/conversion/aten_ops_converters.py | 4 ++- tests/py/dynamo/conversion/test_index_aten.py | 26 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index abf721198b..65923c7dac 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -420,7 +420,9 @@ def index_dtype_validator( @dynamo_tensorrt_converter( - torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator + torch.ops.aten.index.Tensor, + capability_validator=index_dtype_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 8e21f945dc..fc4a70b1ff 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -168,7 +168,31 @@ def forward(self, input): dtype=torch.float32, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_specs) + self.run_test_with_dynamic_shape( + TestModule(), input_specs, use_dynamo_tracer=True + ) + + +class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase): + def test_index_input_non_dynamic_index_dynamic(self): + class TestIndexWithRuntimeIndex(torch.nn.Module): + def forward(self, x): + mask = x > 0 + idx = torch.nonzero(mask, as_tuple=True) + return torch.ops.aten.index.Tensor(x, idx) + + input_specs = [ + Input( + min_shape=(2, 2), + opt_shape=(2, 2), + max_shape=(8, 8), + dtype=torch.float32, + ), + ] + # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True + self.run_test_with_dynamic_shape( + TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True + ) if __name__ == "__main__": From 879410ba818cd2fc56027e214f64db4da5e7f765 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 31 Jul 2025 15:53:21 -0700 Subject: [PATCH 05/12] support for boolean indices --- .../dynamo/conversion/aten_ops_converters.py | 6 +- .../dynamo/conversion/impl/select.py | 62 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 65923c7dac..178caa17c1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -414,7 +414,11 @@ def index_dtype_validator( for 
ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype not in (torch.int32, torch.int64): + if val is not None and val.dtype not in ( + torch.int32, + torch.int64, + torch.bool, + ): return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index c4d44a07ea..10a8332538 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -51,6 +51,65 @@ def select( return layer.get_output(0) +def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: + if isinstance(tensor, (TRTTensor)): + val = tensor.meta.get("val") + if val is not None and val.dtype is torch.bool: + return True + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool + + +def expand_boolean_indices( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + for i, ind in enumerate(indices): + if ind is not None and is_boolean_tensor(ind): + _LOGGER.debug( + f"Boolean index detected at position {i}, converting with nonzero()" + ) + + mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") + + nonzero_layer = ctx.net.add_non_zero(mask_tensor) + set_layer_name( + nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir + ) + nonzero_indices = nonzero_layer.get_output(0) + + # nonzero returns shape [N, dims], we need to extract dim i + if len(indices) == 1: + # x[mask] — 1D mask + squeeze_layer = ctx.net.add_shuffle(nonzero_indices) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + target, + name + f"_bool_nonzero_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + ind = squeezed_index + else: + # Advanced multi-axis mask: extract index i from shape [N, D] + gather_axis = 1 # dim index + gather_layer = ctx.net.add_gather( + nonzero_indices, + get_trt_tensor(ctx, i, name + f"_dim_index_{i}"), + gather_axis, + ) + set_layer_name( + gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir + ) + extracted_index = gather_layer.get_output(0) + ind = extracted_index + return indices + + def index( ctx: ConversionContext, target: Target, @@ -61,8 +120,6 @@ def index( ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # check if the input is dynamic - dynamic_shape = has_dynamic_shape(input.shape) # is_numpy is a flag to specify if all the indices are numpy or torchTensor. 
# If any is not this flag will be set to False _LOGGER.debug( @@ -76,6 +133,7 @@ def index( # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None + indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") From d2778f5057a13612540ed124327a605c3572c642 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 21 Aug 2025 17:28:39 -0700 Subject: [PATCH 06/12] mask test cases and correction --- .../dynamo/conversion/aten_ops_converters.py | 1 + .../dynamo/conversion/impl/select.py | 22 +++++++++++++------ tests/py/dynamo/conversion/test_index_aten.py | 10 +++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 178caa17c1..e9be9c9b89 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -427,6 +427,7 @@ def index_dtype_validator( torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator, supports_dynamic_shapes=True, + requires_output_allocator=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 10a8332538..85de72893d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -14,7 +14,6 @@ cast_trt_tensor, get_positive_dim, get_trt_tensor, - has_dynamic_shape, set_layer_name, to_numpy, ) @@ -52,10 +51,14 @@ def select( def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: - if isinstance(tensor, (TRTTensor)): + if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)): + return bool(tensor.dtype == torch.bool) + # when index is a node + else: val = tensor.meta.get("val") if val is not None and val.dtype is torch.bool: return True + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool @@ -67,12 +70,12 @@ def expand_boolean_indices( input: TRTTensor, indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + new_indices = [] for i, ind in enumerate(indices): if ind is not None and is_boolean_tensor(ind): _LOGGER.debug( f"Boolean index detected at position {i}, converting with nonzero()" ) - mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") nonzero_layer = ctx.net.add_non_zero(mask_tensor) @@ -93,7 +96,7 @@ def expand_boolean_indices( source_ir, ) squeezed_index = squeeze_layer.get_output(0) - ind = squeezed_index + new_indices.append(squeezed_index) else: # Advanced multi-axis mask: extract index i from shape [N, D] gather_axis = 1 # dim index @@ -106,8 +109,13 @@ def expand_boolean_indices( gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir ) extracted_index = gather_layer.get_output(0) - ind = extracted_index - return indices + squeeze_layer = ctx.net.add_shuffle(extracted_index) + squeeze_layer.reshape_dims = (-1,) + squeezed_index = squeeze_layer.get_output(0) + new_indices.append(squeezed_index) + else: + new_indices.append(ind) + return new_indices def index( @@ -125,6 +133,7 @@ def index( _LOGGER.debug( "Determining whether aten.index constant-index optimization can be invoked" ) + indices = expand_boolean_indices(ctx, target, source_ir, name, input, 
indices) is_numpy = all( isinstance(ind, (torch.Tensor, np.ndarray)) for ind in indices @@ -133,7 +142,6 @@ def index( # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None - indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index fc4a70b1ff..5aa44c02c4 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -71,6 +71,16 @@ class TestIndexConstantConverter(DispatchTestCase): [None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])], torch.randn(2, 4, 4, 2), ), + ( + "mask_index_three_dim", + [None, torch.tensor([True, False]), None], + torch.randn(2, 2, 2), + ), + ( + "mask_index_two_dim", + [torch.tensor([True, False])], + torch.randn(2, 2), + ), ] ) def test_index_constant(self, _, index, input): From 76f7f55be9283c0052d27e6eeb555d6b6e2bb15f Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 22 Aug 2025 16:26:04 -0700 Subject: [PATCH 07/12] adding the discontinuous mask indices case --- tests/py/dynamo/conversion/test_index_aten.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 5aa44c02c4..f7278f84a6 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -81,6 +81,17 @@ class TestIndexConstantConverter(DispatchTestCase): [torch.tensor([True, False])], torch.randn(2, 2), ), + ( + # covers multi axis and discontinuous indices + "mask_index_multi_axis", + [ + None, + torch.tensor([[True, False, False, True]]), # axis 1 + None, + torch.tensor([True, False]), # axis 3 + ], + torch.randn(2, 4, 4, 2), + ), ] ) def test_index_constant(self, _, index, input): From 51a60a5e25e8fd884f889dea88787166052777c8 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 22 Aug 2025 16:39:33 -0700 Subject: [PATCH 08/12] unifying the squeee layer --- .../dynamo/conversion/impl/select.py | 27 +++++++++---------- tests/py/dynamo/conversion/test_index_aten.py | 2 +- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 85de72893d..ded50519ad 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -87,16 +87,7 @@ def expand_boolean_indices( # nonzero returns shape [N, dims], we need to extract dim i if len(indices) == 1: # x[mask] — 1D mask - squeeze_layer = ctx.net.add_shuffle(nonzero_indices) - squeeze_layer.reshape_dims = (-1,) - set_layer_name( - squeeze_layer, - target, - name + f"_bool_nonzero_squeeze_{i}", - source_ir, - ) - squeezed_index = squeeze_layer.get_output(0) - new_indices.append(squeezed_index) + to_squeeze = nonzero_indices else: # Advanced multi-axis mask: extract index i from shape [N, D] gather_axis = 1 # dim index @@ -108,11 +99,17 @@ def expand_boolean_indices( set_layer_name( gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir ) - extracted_index = gather_layer.get_output(0) - squeeze_layer = ctx.net.add_shuffle(extracted_index) - squeeze_layer.reshape_dims = (-1,) - squeezed_index = squeeze_layer.get_output(0) - new_indices.append(squeezed_index) + to_squeeze = 
gather_layer.get_output(0) + squeeze_layer = ctx.net.add_shuffle(to_squeeze) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + target, + name + f"_bool_mask_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + new_indices.append(squeezed_index) else: new_indices.append(ind) return new_indices diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index f7278f84a6..e069fab263 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -86,7 +86,7 @@ class TestIndexConstantConverter(DispatchTestCase): "mask_index_multi_axis", [ None, - torch.tensor([[True, False, False, True]]), # axis 1 + torch.tensor([True, False]), # axis 1 None, torch.tensor([True, False]), # axis 3 ], From f9b66ad22ce0b5f2a03815a7e47acf61adb10925 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 28 Aug 2025 16:17:36 -0700 Subject: [PATCH 09/12] register lowering pass with model config --- .../lowering/passes/_aten_lowering_pass.py | 26 +- tools/llm/run_llm.py | 1 + tools/llm/torchtrt_ext/register_sdpa.py | 288 ++++++++++-------- tools/llm/torchtrt_ext/sdpa_converter.py | 14 +- tools/llm/utils.py | 1 - 5 files changed, 195 insertions(+), 135 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 516c371e48..a9fe77533d 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import torch from torch_tensorrt.dynamo._settings import CompilationSettings @@ -55,20 +55,28 @@ def _aten_lowering_pass( *args: LoweringPassSignature, index: Optional[int] = None, + **kwargs: Any, ) -> Union[ LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature] ]: """Adds a lowering pass to the registry, at a specified index if desired If no index is specified, the lowering pass is inserted at the end of the list + + Additional keyword arguments can be passed to configure the lowering pass behavior. + These will be stored as metadata on the pass function. 
""" def add_lowering_pass( lowering_pass: LoweringPassSignature, ) -> LoweringPassSignature: + # Store additional parameters as metadata on the function + if kwargs: + lowering_pass._lowering_pass_config = kwargs + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return lowering_pass @@ -83,7 +91,7 @@ def add_lowering_pass( f"aten_lowering_pass decorator called with invalid arguments {args} " "To specify an index to insert the pass, use the keyword 'index='" ) - # If no arguments are specified, the decorator was called with an index keyword + # If no arguments are specified, the decorator was called with keyword arguments else: return add_lowering_pass @@ -97,6 +105,18 @@ def _remove_lowering_pass(*, index: int) -> None: return +def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]: + """Get the configuration parameters for a lowering pass function + + Args: + lowering_pass: The lowering pass function + + Returns: + Dictionary containing the configuration parameters, or empty dict if none + """ + return getattr(lowering_pass, "_lowering_pass_config", {}) + + def post_lowering( gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings() ) -> torch.fx.GraphModule: diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7e50b515c2..cda0b1a96c 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -58,6 +58,7 @@ def get_model(args): .eval() .cuda() ) + register_sdpa.register_sdpa_pass_with_model_config(model_config=model.config) if args.precision == "FP16": model = model.to(torch.float16) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 9aafb2ad11..570b50dd7a 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -13,6 +13,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) +from transformers import Gemma3TextConfig from .sdpa_converter import * @@ -34,134 +35,175 @@ } -@_aten_lowering_pass -def replace_variants_of_sdpa( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace scaled_dot_product_attention with an equivalent - implementation which can be accurately converted to TRT +def register_sdpa_pass_with_model_config(index: int = 0, model_config=None): """ + Register the SDPA replacement pass with a specific model configuration. 
- for node in gm.graph.nodes: - attn_mask = None - is_causal = False - if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - if ( - node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ): - if len(node.args) == 7: - ( - query, - key, - value, - attn_mask, - compute_log_sumexp, - dropout_p, - is_causal, - ) = node.args - elif len(node.args) == 5: - query, key, value, attn_mask, is_causal = node.args - dropout_p = 0.0 - - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - elif ( - node.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ): - if len(node.args) == 6: - ( - query, - key, - value, - dropout_p, - is_causal, - return_debug_mask, - ) = node.args - if len(node.args) == 5: - query, key, value, dropout_p, is_causal = node.args - elif len(node.args) == 3: - query, key, value = node.args - dropout_p = 0.0 - is_causal = True - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) + Args: + model_config: The model configuration object (e.g., from transformers.AutoConfig) + index: Position in the lowering pass list (default: 0) + + Example: + from transformers import AutoConfig + config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium") + register_sdpa_pass_with_model_config(config) + """ + from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, + _remove_lowering_pass, + ) + # Create a new pass with the model configuration + @_aten_lowering_pass(index=index, model_config=model_config) + def replace_variants_of_sdpa_with_config( + gm: torch.fx.GraphModule, settings: CompilationSettings + ) -> torch.fx.GraphModule: + """Replace scaled_dot_product_attention with model-specific configuration""" + + # Access the model configuration from the decorator parameters + from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + get_lowering_pass_config, + ) + + config = get_lowering_pass_config(replace_variants_of_sdpa_with_config) + + model_config = config.get("model_config", None) + layer_types = [] + sliding_window = None + # Extract model-specific parameters + if model_config is not None: + if isinstance(model_config, Gemma3TextConfig): + sliding_window = getattr(model_config, "sliding_window", None) + layer_types = getattr(model_config, "layer_types", None) + logger.info(f"Model config: {sliding_window=} {layer_types=}") + else: logger.warning( - f"This current version of SDPA converter only supports attn_mask = {attn_mask}, dropout_p = {dropout_p} and is_causal = {is_causal} configuration. This could cause issues with accuracy for models with different configurations." + "No model configuration provided, using default SDPA replacement behavior" ) - modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) - # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, attn_mask, dropout_p, is_causal). 
kwargs has scale - with gm.graph.inserting_after(node): - new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, - args=modified_input_args, - kwargs={ - "scale": node.kwargs.get("scale", None), - "use_fp32_acc": settings.use_fp32_acc, - }, + index = 0 + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + sliding_window_size = None + if ( + sliding_window is not None + and sliding_window > 0 + and layer_types is not None + and index < len(layer_types) + ): + if layer_types[index] == "sliding_attention": + sliding_window_size = sliding_window + index += 1 + + if ( + node.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ): + if len(node.args) == 7: + ( + query, + key, + value, + attn_mask, + compute_log_sumexp, + dropout_p, + is_causal, + ) = node.args + elif len(node.args) == 5: + query, key, value, attn_mask, is_causal = node.args + dropout_p = 0.0 + + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + elif ( + node.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ): + if len(node.args) == 6: + ( + query, + key, + value, + dropout_p, + is_causal, + return_debug_mask, + ) = node.args + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args + elif len(node.args) == 3: + query, key, value = node.args + dropout_p = 0.0 + is_causal = True + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + + # always set_causal to True and generate attn_mask inside the sdpa operator, do not use the attn_mask from the transformers. + attn_mask = None + is_causal = True + dropout_p = 0.0 + + logger.warning( + f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations." ) + modified_input_args = ( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + ) + # Create a new node with torch.nn.functional.scaled_dot_product_attention + # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.nn.functional.scaled_dot_product_attention, + args=modified_input_args, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + "sliding_window_size": sliding_window_size, + }, + ) + + # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
+ new_node.meta = copy.copy(node.meta) + # Check if there's a getitem node following this attention node + for user in list(node.users): + if ( + user.op == "call_function" + and user.target == operator.getitem + ): + # If the getitem is extracting the first element (the output tensor) + if user.args[1] == 0: + # Replace all uses of the getitem with the new attention node + user.replace_all_uses_with(new_node) + new_node.meta["val"] = new_node.meta["val"][0] + # Replace all uses of the original node with the new node + node.replace_all_uses_with(new_node) + + gm.graph.erase_node(node) + + # Clean up the graph + clean_up_graph_after_modifications(gm) + + if model_config: + logger.debug( + f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model" + ) + else: + logger.debug( + "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" + ) + add_attn_mask_as_output = False + if add_attn_mask_as_output: + add_one_attn_mask_as_output(gm) + return gm - # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. - new_node.meta = copy.copy(node.meta) - # Check if there's a getitem node following this attention node - for user in list(node.users): - if user.op == "call_function" and user.target == operator.getitem: - # If the getitem is extracting the first element (the output tensor) - if user.args[1] == 0: - # Replace all uses of the getitem with the new attention node - user.replace_all_uses_with(new_node) - new_node.meta["val"] = new_node.meta["val"][0] - # Replace all uses of the original node with the new node - node.replace_all_uses_with(new_node) - - gm.graph.erase_node(node) - - # Clean up the graph - clean_up_graph_after_modifications(gm) - - logger.debug( - "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" + logger.info( + f"Registered SDPA pass with model config: {getattr(model_config, 'model_type', 'unknown')}" ) - add_attn_mask_as_output = False - if add_attn_mask_as_output: - add_one_attn_mask_as_output(gm) - return gm - - -# try to add one of the attn_mask as output, so that I can actually see the shape and value in the generation phase. 
-def add_one_attn_mask_as_output(gm: torch.fx.GraphModule): - import torch.utils._pytree as pytree - from cache_utils import create_random_output_tensors - - attn_mask_node = None - for node in gm.graph.nodes: - if ( - node.op == "call_function" - and node.target == torch.nn.functional.scaled_dot_product_attention - ): - attn_mask_node = node.args[3] - break - - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - current_outputs = output_node.args[0] - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + (attn_mask_node,) - else: - new_outputs = (current_outputs, attn_mask_node) - output_node.args = new_outputs - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - gm = clean_up_graph_after_modifications(gm) - new_output_tensors = create_random_output_tensors(new_outputs) - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - return gm \ No newline at end of file + return replace_variants_of_sdpa_with_config diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 3def60a8ca..bd2dcead68 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -93,15 +93,12 @@ def scaled_dot_product_attention( scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - assert ( - is_causal or attn_mask is not None - ), "One of is_causal or attn_mask should be set" - assert is_causal ^ ( - attn_mask is not None - ), "Exactly one of is_causal or attn_mask should be set" + assert is_causal == True, "is_causal should be set to True" # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) + sliding_window_size = kwargs.get("sliding_window_size", None) + query_dtype = query.dtype if scale is None: @@ -169,7 +166,9 @@ def scaled_dot_product_attention( S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) if is_causal: # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + tril_tensor = tril( + ctx, target, source_ir, name + "_tril", L, S, sliding_window_size + ) temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor @@ -288,7 +287,6 @@ def scaled_dot_product_attention( true_input.get_output(0), false_input.get_output(0) ) scaled_add_attn_bias = output_layer.get_output(0) - softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..c56aa9b490 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = ( From befc0b93a87862e4e8a10a0f82e3c65c8ee4b51a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 28 Aug 2025 17:37:04 -0700 Subject: [PATCH 10/12] clean up code --- tools/llm/torchtrt_ext/sdpa_converter.py | 152 +++++------------------ 1 file changed, 32 insertions(+), 120 deletions(-) diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 5513ffa967..f7a7203f38 100644 --- 
a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -164,130 +164,42 @@ def scaled_dot_product_attention( L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - if is_causal: - # generate the mask tensor - tril_tensor = tril( - ctx, target, source_ir, name + "_tril", L, S, sliding_window_size - ) - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. - # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) - - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias - ) - else: - use_if_conditional = False - if not use_if_conditional: - # works in non cache scenario, but in kv cache, got the following error: - # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [ELEMENTWISE]-[aten_ops.scaled_dot_product_attention]-[model.layers.0.self_attn/scaled_dot_product_attention_attn_mask_add]: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 5 != 71 && 5 != 1 && 71 != 1.) - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_mask_add", mm, attn_mask - ) - else: - if_option = "if_conditional_subgraph" # if_conditional_subgraph or if_conditional or if_conditional_input - if if_option == "if_conditional_subgraph": - # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L46 - # if_conditional_subgraph is not working, got the following error: - # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block - # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. 
Expect the graph has single block)
-
-            need_mask = impl.elementwise.eq(
-                ctx, target, source_ir, name + "_eq", L, S
-            )
-            # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-            condition = impl.squeeze.squeeze(
-                ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-            )
-            if_layer = ctx.net.add_if_conditional()
-            if_layer.set_condition(condition)
-            cond_input1 = if_layer.add_input(mm)
-            cond_input2 = if_layer.add_input(attn_mask)
-
-            true_input = impl.elementwise.add(
-                ctx,
-                target,
-                source_ir,
-                name + "_attn_bias_add",
-                cond_input1.get_output(0),
-                cond_input2.get_output(0),
-            )
-            false_input = cond_input1.get_output(0)
-            output_layer = if_layer.add_output(true_input, false_input)
-            scaled_add_attn_bias = output_layer.get_output(0)
-        elif if_option == "if_conditional_input":
-            # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L17
-            # if_conditional_input is not working, got the following error:
-            # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
-            # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
+        # generate the mask tensor
+        tril_tensor = tril(
+            ctx, target, source_ir, name + "_tril", L, S, sliding_window_size
+        )
-            need_mask = impl.elementwise.eq(
-                ctx, target, source_ir, name + "_eq", L, S
-            )
-            # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-            condition = impl.squeeze.squeeze(
-                ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-            )
-            if_layer = ctx.net.add_if_conditional()
-            if_layer.set_condition(condition)
-            true_input = impl.elementwise.add(
-                ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask
-            )
-            false_input = mm
-            true_cond_input = if_layer.add_input(true_input)
-            false_cond_input = if_layer.add_input(false_input)
-            output_layer = if_layer.add_output(
-                true_cond_input.get_output(0), false_cond_input.get_output(0)
-            )
-            scaled_add_attn_bias = output_layer.get_output(0)
-        elif if_option == "if_conditional":
-            # reference: https://github.com/pytorch/TensorRT/blob/535c6a8341a3258a9c311406a9af50eb3c68c5a6/examples/dynamo/llm/cache_utils.py#L15-L44
-            # if_conditional is not working, got the following error:
-            # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
-            # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
+        temp_mask = impl.unary.logical_not(
+            ctx, target, source_ir, name + "_logical_not", tril_tensor
+        )
-            need_mask = impl.elementwise.eq(
-                ctx, target, source_ir, name + "_eq", L, S
-            )
-            # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-            condition = impl.squeeze.squeeze(
-                ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-            )
-            if_layer = ctx.net.add_if_conditional()
-            if_layer.set_condition(condition)
-            true_input = impl.elementwise.add(
-                ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask
-            )
-            false_input = mm
-            output_layer = if_layer.add_output(
-                true_input.get_output(0), false_input.get_output(0)
-            )
-            scaled_add_attn_bias = output_layer.get_output(0)
+        # This need_mask determines if we want to use the causal mask or not
+        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+        # So need_mask will be all False values in this case.
+        # TODO: Implement more general case where L != 1 and S != L
+        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+        temp_mask = impl.elementwise.logical_and(
+            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+        )
+        temp_mask_casted = cast_trt_tensor(
+            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+        )
+        one_minus_temp_mask = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_one_minus_temp_mask",
+            1.0,
+            temp_mask_casted,
+        )
+        attn_bias = impl.unary.log(
+            ctx, target, source_ir, name + "_log", one_minus_temp_mask
+        )
+        scaled_add_attn_bias = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+        )
 
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
     )

From e59a5293a2a566973424f418ef27782bbc0ce324 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Thu, 28 Aug 2025 17:58:01 -0700
Subject: [PATCH 11/12] test

---
 .github/workflows/release-windows.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/release-windows.yml b/.github/workflows/release-windows.yml
index 1a20eb7a71..1dd827ec1c 100644
--- a/.github/workflows/release-windows.yml
+++ b/.github/workflows/release-windows.yml
@@ -1,6 +1,7 @@
 name: Release Windows wheels artifacts
 
 on:
+  pull_request:
   push:
     tags:
       # NOTE: Binary build pipelines should only get triggered on release candidate builds

From 6a1a02e9c884b91a736a55456902a39c96c27870 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Fri, 29 Aug 2025 15:54:43 -0700
Subject: [PATCH 12/12] add register for different sdpa

---
 tools/llm/run_llm.py                    |   5 +-
 tools/llm/torchtrt_ext/register_sdpa.py | 287 ++++++++++++------------
 2 files changed, 147 insertions(+), 145 deletions(-)

diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
index cda0b1a96c..075f3ace15 100644
--- a/tools/llm/run_llm.py
+++ b/tools/llm/run_llm.py
@@ -58,7 +58,10 @@ def get_model(args):
         .eval()
         .cuda()
     )
-    register_sdpa.register_sdpa_pass_with_model_config(model_config=model.config)
+    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+    else:
+        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py
index 570b50dd7a..6284dc6d61 100644
--- a/tools/llm/torchtrt_ext/register_sdpa.py
+++ b/tools/llm/torchtrt_ext/register_sdpa.py
@@ -1,7 +1,7 @@
 import copy
 import logging
 import operator
-from typing import Callable, Sequence, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -13,7 +13,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
 
-from transformers import Gemma3TextConfig
+from transformers import AutoConfig, Gemma3TextConfig
 
 from .sdpa_converter import *
 
@@ -34,52 +34,130 @@
     torch.ops.aten._scaled_dot_product_flash_attention.default,
 }
 
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    get_lowering_pass_config,
+)
+
-def register_sdpa_pass_with_model_config(index: int = 0, model_config=None):
-    """
-    Register the SDPA replacement pass with a specific model configuration.
+def _process_sdpa_node(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    settings: CompilationSettings,
+    sliding_window_size: Optional[int] = None,
+    use_gqa: bool = False,
+) -> torch.fx.GraphModule:
+    """Helper function to process SDPA nodes with common logic."""
+
+    if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+        if len(node.args) == 7:
+            (
+                query,
+                key,
+                value,
+                attn_mask,
+                compute_log_sumexp,
+                dropout_p,
+                is_causal,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, attn_mask, is_causal = node.args
+            dropout_p = 0.0
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+        if len(node.args) == 6:
+            (
+                query,
+                key,
+                value,
+                dropout_p,
+                is_causal,
+                return_debug_mask,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, dropout_p, is_causal = node.args
+        elif len(node.args) == 3:
+            query, key, value = node.args
+            dropout_p = 0.0
+            is_causal = True
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    else:
+        return gm
 
-    Args:
-        model_config: The model configuration object (e.g., from transformers.AutoConfig)
-        index: Position in the lowering pass list (default: 0)
+    # Always set causal to True and generate attn_mask inside the sdpa operator
+    attn_mask = None
+    is_causal = True
+    dropout_p = 0.0
 
-    Example:
-        from transformers import AutoConfig
-        config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
-        register_sdpa_pass_with_model_config(config)
-    """
-    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-        _aten_lowering_pass,
-        _remove_lowering_pass,
+    logger.warning(
+        f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, "
+        f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}"
     )
 
-    # Create a new pass with the model configuration
-    @_aten_lowering_pass(index=index, model_config=model_config)
-    def replace_variants_of_sdpa_with_config(
-        gm: torch.fx.GraphModule, settings: CompilationSettings
-    ) -> torch.fx.GraphModule:
-        """Replace scaled_dot_product_attention with model-specific configuration"""
+    modified_input_args = (
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+    )
 
-        # Access the model configuration from the decorator parameters
-        from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-            get_lowering_pass_config,
+    # Create a new node with torch.nn.functional.scaled_dot_product_attention
+    with gm.graph.inserting_after(node):
+        new_node = gm.graph.call_function(
+            torch.nn.functional.scaled_dot_product_attention,
+            args=modified_input_args,
+            kwargs={
+                "scale": node.kwargs.get("scale", None),
+                "use_fp32_acc": settings.use_fp32_acc,
+                "sliding_window_size": sliding_window_size,
+            },
         )
-        config = get_lowering_pass_config(replace_variants_of_sdpa_with_config)
+        # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+        new_node.meta = copy.copy(node.meta)
+        # Check if there's a getitem node following this attention node
+        for user in list(node.users):
+            if user.op == "call_function" and user.target == operator.getitem:
+                # If the getitem is extracting the first element (the output tensor)
+                if user.args[1] == 0:
+                    # Replace all uses of the getitem with the new attention node
+                    user.replace_all_uses_with(new_node)
+                    new_node.meta["val"] = new_node.meta["val"][0]
+        # Replace all uses of the original node with the new node
+        node.replace_all_uses_with(new_node)
 
-        model_config = config.get("model_config", None)
-        layer_types = []
+    gm.graph.erase_node(node)
+    return gm
+
+
+def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def gemma3_sdpa_pass(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        """SDPA pass specifically for Gemma3 models with sliding window attention."""
+        config = get_lowering_pass_config(gemma3_sdpa_pass)
         sliding_window = None
-        # Extract model-specific parameters
-        if model_config is not None:
-            if isinstance(model_config, Gemma3TextConfig):
-                sliding_window = getattr(model_config, "sliding_window", None)
-                layer_types = getattr(model_config, "layer_types", None)
-                logger.info(f"Model config: {sliding_window=} {layer_types=}")
-        else:
+        layer_types = None
+        model_config = config.get("model_config", None)
+        if not isinstance(model_config, Gemma3TextConfig):
             logger.warning(
-                "No model configuration provided, using default SDPA replacement behavior"
+                f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead"
+            )
+        else:
+            sliding_window = getattr(model_config, "sliding_window", None)
+            layer_types = getattr(model_config, "layer_types", None)
+            logger.debug(
+                f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}"
             )
+
         index = 0
         for node in gm.graph.nodes:
             if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
@@ -94,116 +172,37 @@ def replace_variants_of_sdpa_with_config(
                     sliding_window_size = sliding_window
                 index += 1
-                if (
-                    node.target
-                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
-                ):
-                    if len(node.args) == 7:
-                        (
-                            query,
-                            key,
-                            value,
-                            attn_mask,
-                            compute_log_sumexp,
-                            dropout_p,
-                            is_causal,
-                        ) = node.args
-                    elif len(node.args) == 5:
-                        query, key, value, attn_mask, is_causal = node.args
-                        dropout_p = 0.0
-
-                    else:
-                        raise ValueError(
-                            f"Unexpected number of arguments for {node.target} in the graph"
-                        )
-                elif (
-                    node.target
-                    == torch.ops.aten._scaled_dot_product_flash_attention.default
-                ):
-                    if len(node.args) == 6:
-                        (
-                            query,
-                            key,
-                            value,
-                            dropout_p,
-                            is_causal,
-                            return_debug_mask,
-                        ) = node.args
-                    if len(node.args) == 5:
-                        query, key, value, dropout_p, is_causal = node.args
-                    elif len(node.args) == 3:
-                        query, key, value = node.args
-                        dropout_p = 0.0
-                        is_causal = True
-                    else:
-                        raise ValueError(
-                            f"Unexpected number of arguments for {node.target} in the graph"
-                        )
-
-                # always set_causal to True and generate attn_mask inside the sdpa operator, do not use the attn_mask from the transformers.
-                attn_mask = None
-                is_causal = True
-                dropout_p = 0.0
-
-                logger.warning(
-                    f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations."
-                )
-                modified_input_args = (
-                    query,
-                    key,
-                    value,
-                    attn_mask,
-                    dropout_p,
-                    is_causal,
+                # Process the node
+                logger.debug(
+                    f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}"
                 )
-                # Create a new node with torch.nn.functional.scaled_dot_product_attention
-                # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
-                with gm.graph.inserting_after(node):
-                    new_node = gm.graph.call_function(
-                        torch.nn.functional.scaled_dot_product_attention,
-                        args=modified_input_args,
-                        kwargs={
-                            "scale": node.kwargs.get("scale", None),
-                            "use_fp32_acc": settings.use_fp32_acc,
-                            "sliding_window_size": sliding_window_size,
-                        },
-                    )
-
-                    # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
-                    new_node.meta = copy.copy(node.meta)
-                    # Check if there's a getitem node following this attention node
-                    for user in list(node.users):
-                        if (
-                            user.op == "call_function"
-                            and user.target == operator.getitem
-                        ):
-                            # If the getitem is extracting the first element (the output tensor)
-                            if user.args[1] == 0:
-                                # Replace all uses of the getitem with the new attention node
-                                user.replace_all_uses_with(new_node)
-                                new_node.meta["val"] = new_node.meta["val"][0]
-                    # Replace all uses of the original node with the new node
-                    node.replace_all_uses_with(new_node)
-
-                    gm.graph.erase_node(node)
-
-        # Clean up the graph
+                gm = _process_sdpa_node(gm, node, settings, sliding_window_size)
+        clean_up_graph_after_modifications(gm)
+        logger.debug("Applied Gemma3-specific SDPA replacement")
+        return gm
-        if model_config:
-            logger.debug(
-                f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model"
-            )
-        else:
-            logger.debug(
-                "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
-            )
-        add_attn_mask_as_output = False
-        if add_attn_mask_as_output:
-            add_one_attn_mask_as_output(gm)
+
+def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def default_sdpa_pass(
+        gm: torch.fx.GraphModule,
+        settings: CompilationSettings,
+    ) -> torch.fx.GraphModule:
+        """Default SDPA pass for models without specific implementations."""
+
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                # Process the node with default logic
+                gm = _process_sdpa_node(gm, node, settings)
+
+        clean_up_graph_after_modifications(gm)
+        logger.debug("Applied default SDPA replacement")
         return gm
-    logger.info(
-        f"Registered SDPA pass with model config: {getattr(model_config, 'model_type', 'unknown')}"
-    )
-    return replace_variants_of_sdpa_with_config
+
+# Global registry for SDPA passes
+_SDPA_MAPPING: Dict[str, Callable] = {
+    "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
+    "default": register_default_sdpa_pass,
+}