diff --git a/.github/workflows/release-windows.yml b/.github/workflows/release-windows.yml index 1a20eb7a71..1dd827ec1c 100644 --- a/.github/workflows/release-windows.yml +++ b/.github/workflows/release-windows.yml @@ -1,6 +1,7 @@ name: Release Windows wheels artifacts on: + pull_request: push: tags: # NOTE: Binary build pipelines should only get triggered on release candidate builds diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7b91eec34..c12918fdc9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,7 +47,8 @@ repos: hooks: - id: ruff - repo: https://github.com/psf/black - rev: 25.1.0 + # pin to a lower version for py3.9 compatibility + rev: 23.12.1 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4482a00f79..4dcb525405 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -415,7 +415,11 @@ def index_dtype_validator( for ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype not in (torch.int32, torch.int64): + if val is not None and val.dtype not in ( + torch.int32, + torch.int64, + torch.bool, + ): return False return True @@ -424,6 +428,7 @@ def index_dtype_validator( torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator, supports_dynamic_shapes=True, + requires_output_allocator=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index c4d44a07ea..ded50519ad 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -14,7 +14,6 @@ cast_trt_tensor, get_positive_dim, get_trt_tensor, - has_dynamic_shape, set_layer_name, to_numpy, ) @@ -51,6 +50,71 @@ def select( return layer.get_output(0) +def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: + if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)): + return bool(tensor.dtype == torch.bool) + # when index is a node + else: + val = tensor.meta.get("val") + if val is not None and val.dtype is torch.bool: + return True + + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool + + +def expand_boolean_indices( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + new_indices = [] + for i, ind in enumerate(indices): + if ind is not None and is_boolean_tensor(ind): + _LOGGER.debug( + f"Boolean index detected at position {i}, converting with nonzero()" + ) + mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") + + nonzero_layer = ctx.net.add_non_zero(mask_tensor) + set_layer_name( + nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir + ) + nonzero_indices = nonzero_layer.get_output(0) + + # nonzero returns shape [N, dims], we need to extract dim i + if len(indices) == 1: + # x[mask] — 1D mask + to_squeeze = nonzero_indices + else: + # Advanced multi-axis mask: extract index i from shape [N, D] + gather_axis = 1 # dim index + gather_layer = ctx.net.add_gather( + nonzero_indices, + get_trt_tensor(ctx, i, name + f"_dim_index_{i}"), + gather_axis, + ) + 
set_layer_name( + gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir + ) + to_squeeze = gather_layer.get_output(0) + squeeze_layer = ctx.net.add_shuffle(to_squeeze) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + target, + name + f"_bool_mask_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + new_indices.append(squeezed_index) + else: + new_indices.append(ind) + return new_indices + + def index( ctx: ConversionContext, target: Target, @@ -61,13 +125,12 @@ def index( ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # check if the input is dynamic - dynamic_shape = has_dynamic_shape(input.shape) # is_numpy is a flag to specify if all the indices are numpy or torchTensor. # If any is not this flag will be set to False _LOGGER.debug( "Determining whether aten.index constant-index optimization can be invoked" ) + indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) is_numpy = all( isinstance(ind, (torch.Tensor, np.ndarray)) for ind in indices diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 1fc1b9b420..7f07154eb6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import torch from torch_tensorrt.dynamo._settings import CompilationSettings @@ -53,20 +53,28 @@ def _aten_lowering_pass( *args: LoweringPassSignature, index: Optional[int] = None, + **kwargs: Any, ) -> Union[ LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature] ]: """Adds a lowering pass to the registry, at a specified index if desired If no index is specified, the lowering pass is inserted at the end of the list + + Additional keyword arguments can be passed to configure the lowering pass behavior. + These will be stored as metadata on the pass function. 
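+
+    Example (illustrative only; ``my_pass`` and ``my_option`` are hypothetical names,
+    not part of the existing pass registry)::
+
+        @_aten_lowering_pass(index=0, my_option=42)
+        def my_pass(
+            gm: torch.fx.GraphModule, settings: CompilationSettings
+        ) -> torch.fx.GraphModule:
+            # The extra keyword is stored on the function and can later be read
+            # back via get_lowering_pass_config(my_pass) -> {"my_option": 42}
+            return gm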
""" def add_lowering_pass( lowering_pass: LoweringPassSignature, ) -> LoweringPassSignature: + # Store additional parameters as metadata on the function + if kwargs: + lowering_pass._lowering_pass_config = kwargs + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return lowering_pass @@ -81,7 +89,7 @@ def add_lowering_pass( f"aten_lowering_pass decorator called with invalid arguments {args} " "To specify an index to insert the pass, use the keyword 'index='" ) - # If no arguments are specified, the decorator was called with an index keyword + # If no arguments are specified, the decorator was called with keyword arguments else: return add_lowering_pass @@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None: return +def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]: + """Get the configuration parameters for a lowering pass function + + Args: + lowering_pass: The lowering pass function + + Returns: + Dictionary containing the configuration parameters, or empty dict if none + """ + return getattr(lowering_pass, "_lowering_pass_config", {}) + + def post_lowering( gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings() ) -> torch.fx.GraphModule: diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 8e21f945dc..e069fab263 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase): [None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])], torch.randn(2, 4, 4, 2), ), + ( + "mask_index_three_dim", + [None, torch.tensor([True, False]), None], + torch.randn(2, 2, 2), + ), + ( + "mask_index_two_dim", + [torch.tensor([True, False])], + torch.randn(2, 2), + ), + ( + # covers multi axis and discontinuous indices + "mask_index_multi_axis", + [ + None, + torch.tensor([True, False]), # axis 1 + None, + torch.tensor([True, False]), # axis 3 + ], + torch.randn(2, 4, 4, 2), + ), ] ) def test_index_constant(self, _, index, input): @@ -168,7 +189,31 @@ def forward(self, input): dtype=torch.float32, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_specs) + self.run_test_with_dynamic_shape( + TestModule(), input_specs, use_dynamo_tracer=True + ) + + +class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase): + def test_index_input_non_dynamic_index_dynamic(self): + class TestIndexWithRuntimeIndex(torch.nn.Module): + def forward(self, x): + mask = x > 0 + idx = torch.nonzero(mask, as_tuple=True) + return torch.ops.aten.index.Tensor(x, idx) + + input_specs = [ + Input( + min_shape=(2, 2), + opt_shape=(2, 2), + max_shape=(8, 8), + dtype=torch.float32, + ), + ] + # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True + self.run_test_with_dynamic_shape( + TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True + ) if __name__ == "__main__": diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7e50b515c2..075f3ace15 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -58,6 +58,10 @@ def get_model(args): .eval() .cuda() ) + if register_sdpa._SDPA_MAPPING.get(args.model, None) is 
not None: + register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) + else: + register_sdpa._SDPA_MAPPING["default"](model_config=model.config) if args.precision == "FP16": model = model.to(torch.float16) diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py index 4634b79a52..6d6c4409e9 100644 --- a/tools/llm/static_cache_v2.py +++ b/tools/llm/static_cache_v2.py @@ -254,7 +254,7 @@ def insert_kv_slicing_before_sdpa( sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + ( attn_mask, dropout_p, - True, + is_causal, ) # kv_cache_for_graph.extend([k_node, v_node]) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..6284dc6d61 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,7 +1,7 @@ import copy import logging import operator -from typing import Callable, Sequence, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type import torch from torch_tensorrt.dynamo._settings import CompilationSettings @@ -13,6 +13,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) +from transformers import AutoConfig, Gemma3TextConfig from .sdpa_converter import * @@ -33,94 +34,175 @@ torch.ops.aten._scaled_dot_product_flash_attention.default, } +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + get_lowering_pass_config, +) -@_aten_lowering_pass -def replace_variants_of_sdpa( - gm: torch.fx.GraphModule, settings: CompilationSettings + +def _process_sdpa_node( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + settings: CompilationSettings, + sliding_window_size: Optional[int] = None, + use_gqa: bool = False, ) -> torch.fx.GraphModule: - """Replace scaled_dot_product_attention with an equivalent - implementation which can be accurately converted to TRT - """ + """Helper function to process SDPA nodes with common logic.""" + + if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if len(node.args) == 7: + ( + query, + key, + value, + attn_mask, + compute_log_sumexp, + dropout_p, + is_causal, + ) = node.args + elif len(node.args) == 5: + query, key, value, attn_mask, is_causal = node.args + dropout_p = 0.0 + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default: + if len(node.args) == 6: + ( + query, + key, + value, + dropout_p, + is_causal, + return_debug_mask, + ) = node.args + elif len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args + elif len(node.args) == 3: + query, key, value = node.args + dropout_p = 0.0 + is_causal = True + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + else: + return gm + + # Always set causal to True and generate attn_mask inside the sdpa operator attn_mask = None is_causal = True - for node in gm.graph.nodes: - if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - if ( - node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ): - if len(node.args) == 7: - ( - query, - key, - value, - attn_bias, - compute_log_sumexp, - dropout_p, - is_causal, - ) = node.args - elif len(node.args) == 5: - query, key, value, attn_mask, is_causal = node.args - dropout_p = 0.0 - - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) 
- elif ( - node.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ): - if len(node.args) == 6: - query, key, value, dropout_p, is_causal, return_debug_mask = ( - node.args - ) - if len(node.args) == 5: - query, key, value, dropout_p, is_causal = node.args - elif len(node.args) == 3: - query, key, value = node.args - dropout_p = 0.0 - is_causal = True - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) + dropout_p = 0.0 + + logger.warning( + f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, " + f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}" + ) + + modified_input_args = ( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + ) + + # Create a new node with torch.nn.functional.scaled_dot_product_attention + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.nn.functional.scaled_dot_product_attention, + args=modified_input_args, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + "sliding_window_size": sliding_window_size, + }, + ) + + # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. + new_node.meta = copy.copy(node.meta) + # Check if there's a getitem node following this attention node + for user in list(node.users): + if user.op == "call_function" and user.target == operator.getitem: + # If the getitem is extracting the first element (the output tensor) + if user.args[1] == 0: + # Replace all uses of the getitem with the new attention node + user.replace_all_uses_with(new_node) + new_node.meta["val"] = new_node.meta["val"][0] + # Replace all uses of the original node with the new node + node.replace_all_uses_with(new_node) + + gm.graph.erase_node(node) + return gm + +def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + @_aten_lowering_pass(index=index, model_config=model_config) + def gemma3_sdpa_pass( + gm: torch.fx.GraphModule, settings: CompilationSettings + ) -> torch.fx.GraphModule: + """SDPA pass specifically for Gemma3 models with sliding window attention.""" + config = get_lowering_pass_config(gemma3_sdpa_pass) + sliding_window = None + layer_types = None + model_config = config.get("model_config", None) + if not isinstance(model_config, Gemma3TextConfig): logger.warning( - f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." + f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead" ) - modified_input_args = (query, key, value, None, dropout_p, True) - # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, is_causal). 
kwargs has scale - with gm.graph.inserting_after(node): - new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, - args=modified_input_args, - kwargs={ - "scale": node.kwargs.get("scale", None), - "use_fp32_acc": settings.use_fp32_acc, - }, + else: + sliding_window = getattr(model_config, "sliding_window", None) + layer_types = getattr(model_config, "layer_types", None) + logger.debug( + f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}" + ) + + index = 0 + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + sliding_window_size = None + if ( + sliding_window is not None + and sliding_window > 0 + and layer_types is not None + and index < len(layer_types) + ): + if layer_types[index] == "sliding_attention": + sliding_window_size = sliding_window + index += 1 + + # Process the node + logger.debug( + f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}" ) + gm = _process_sdpa_node(gm, node, settings, sliding_window_size) - # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. - new_node.meta = copy.copy(node.meta) - # Check if there's a getitem node following this attention node - for user in list(node.users): - if user.op == "call_function" and user.target == operator.getitem: - # If the getitem is extracting the first element (the output tensor) - if user.args[1] == 0: - # Replace all uses of the getitem with the new attention node - user.replace_all_uses_with(new_node) - new_node.meta["val"] = new_node.meta["val"][0] - # Replace all uses of the original node with the new node - node.replace_all_uses_with(new_node) - - gm.graph.erase_node(node) - - # Clean up the graph - clean_up_graph_after_modifications(gm) - - logger.debug( - "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" - ) - return gm + clean_up_graph_after_modifications(gm) + logger.debug("Applied Gemma3-specific SDPA replacement") + return gm + + +def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + @_aten_lowering_pass(index=index, model_config=model_config) + def default_sdpa_pass( + gm: torch.fx.GraphModule, + settings: CompilationSettings, + ) -> torch.fx.GraphModule: + """Default SDPA pass for models without specific implementations.""" + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + # Process the node with default logic + gm = _process_sdpa_node(gm, node, settings) + + clean_up_graph_after_modifications(gm) + logger.debug("Applied default SDPA replacement") + return gm + + +# Global registry for SDPA passes +_SDPA_MAPPING: Dict[str, Callable] = { + "google/gemma-3-1b-it": register_gemma3_sdpa_pass, + "default": register_default_sdpa_pass, +} diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index d05b0379a4..f7a7203f38 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,24 +27,50 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = None, ) -> TRTTensor: row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - 
col_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask @@ -66,9 +92,13 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + + assert is_causal == True, "is_causal should be set to True" + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) + sliding_window_size = kwargs.get("sliding_window_size", None) + query_dtype = query.dtype if scale is None: @@ -136,7 +166,9 @@ def scaled_dot_product_attention( S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + tril_tensor = tril( + ctx, target, source_ir, name + "_tril", L, S, sliding_window_size + ) temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor @@ -165,11 +197,9 @@ def scaled_dot_product_attention( attn_bias = impl.unary.log( ctx, target, source_ir, name + "_log", one_minus_temp_mask ) - - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias - ) - + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..c56aa9b490 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, 
input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = (