diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst
index f6da87b145..8e1a8bbf01 100644
--- a/docsrc/tutorials/compile_hf_models.rst
+++ b/docsrc/tutorials/compile_hf_models.rst
@@ -59,6 +59,10 @@ We have officially verified support for the following LLM families:
        | Qwen/Qwen2.5-7B-Instruct
      - FP16, FP32
      - Yes
+   * - Gemma 3
+     - | google/gemma-3-1b-it
+     - FP16, FP32
+     - Yes
 
 Getting Started with run_llm.py
 -------------------------------
@@ -185,8 +189,8 @@ The number of key/value cache tensors is equal to the number of attention heads
 Generating Outputs
 -------------------
-We use custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
-There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching.
+We use a custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
+There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching.
 
 The ``generate_with_static_cache`` function takes care of preparing the inputs to the model compiled with static KV cache.
 The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ...., ``start_idx``, ``end_idx``.
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 1fc1b9b420..7f07154eb6 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -53,20 +53,28 @@ def _aten_lowering_pass(
     *args: LoweringPassSignature,
     index: Optional[int] = None,
+    **kwargs: Any,
 ) -> Union[
     LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
 ]:
     """Adds a lowering pass to the registry, at a specified index if desired
 
     If no index is specified, the lowering pass is inserted at the end of the list.
+
+    Additional keyword arguments can be passed to configure the lowering pass behavior.
+    These will be stored as metadata on the pass function.
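+
+    Example (a minimal sketch; ``my_pass`` and ``window_size`` are hypothetical names):
+
+        @_aten_lowering_pass(index=0, window_size=128)
+        def my_pass(gm: torch.fx.GraphModule, settings: CompilationSettings) -> torch.fx.GraphModule:
+            # the stored kwargs are retrievable via get_lowering_pass_config(my_pass)
+            return gm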
""" def add_lowering_pass( lowering_pass: LoweringPassSignature, ) -> LoweringPassSignature: + # Store additional parameters as metadata on the function + if kwargs: + lowering_pass._lowering_pass_config = kwargs + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return lowering_pass @@ -81,7 +89,7 @@ def add_lowering_pass( f"aten_lowering_pass decorator called with invalid arguments {args} " "To specify an index to insert the pass, use the keyword 'index='" ) - # If no arguments are specified, the decorator was called with an index keyword + # If no arguments are specified, the decorator was called with keyword arguments else: return add_lowering_pass @@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None: return +def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]: + """Get the configuration parameters for a lowering pass function + + Args: + lowering_pass: The lowering pass function + + Returns: + Dictionary containing the configuration parameters, or empty dict if none + """ + return getattr(lowering_pass, "_lowering_pass_config", {}) + + def post_lowering( gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings() ) -> torch.fx.GraphModule: diff --git a/tests/py/dynamo/models/test_llm_models.py b/tests/py/dynamo/models/test_llm_models.py new file mode 100644 index 0000000000..188954f68d --- /dev/null +++ b/tests/py/dynamo/models/test_llm_models.py @@ -0,0 +1,60 @@ +import os +import sys + +import pytest +import torch +import torch_tensorrt +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../tools/llm")) +import argparse + +from run_llm import compile_torchtrt +from torchtrt_ext import register_sdpa + + +@pytest.mark.unit +@pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) +def test_gemma3_decoder_layer(precision): + + with torch.inference_mode(): + args = argparse.Namespace() + args.debug = False + args.num_tokens = 128 + args.model = "google/gemma-3-1b-it" + args.precision = precision + args.min_block_size = 1 + args.prompt = "What is parallel programming ?" 
+        if args.precision == "FP16":
+            dtype = torch.float16
+        elif args.precision == "BF16":
+            dtype = torch.bfloat16
+        else:
+            args.precision = "FP32"
+            dtype = torch.float32
+
+        model = (
+            AutoModelForCausalLM.from_pretrained(
+                args.model,
+                use_cache=False,
+                attn_implementation="sdpa",
+                num_hidden_layers=1,
+            )
+            .eval()
+            .to("cuda")
+        )
+
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+        model = model.to(dtype)
+        # using randint can generate token ids that produce NaN values in the logits, so use fixed input_ids for now
+        # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda")
+        input_ids = torch.tensor([[2, 3689, 563, 10616, 14929, 2360]]).to("cuda")
+
+        position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to("cuda")
+        pyt_outputs = model(input_ids.clone(), position_ids=position_ids.clone())
+        trt_model = compile_torchtrt(model, input_ids, args)
+        trt_outputs = trt_model(input_ids, position_ids=position_ids)
+
+        torch.testing.assert_close(
+            pyt_outputs.logits, trt_outputs.logits, rtol=5e-1, atol=5e-1
+        )
diff --git a/tools/llm/README.md b/tools/llm/README.md
index a141505517..00a02ecb7b 100644
--- a/tools/llm/README.md
+++ b/tools/llm/README.md
@@ -23,6 +23,7 @@ We have officially verified support for the following models:
 | LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct<br>meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes |
 | Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-4B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
 | Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
+| Gemma 3 | google/gemma-3-1b-it | FP16, FP32 | Yes |
 
 ### Usage
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
index 7e50b515c2..ab9470cc61 100644
--- a/tools/llm/run_llm.py
+++ b/tools/llm/run_llm.py
@@ -58,6 +58,11 @@ def get_model(args):
         .eval()
         .cuda()
     )
+    # register the SDPA lowering pass variant for this model, falling back to the default pass
+    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+    else:
+        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py
index 90a00a5798..a650dc1387 100644
--- a/tools/llm/torchtrt_ext/register_sdpa.py
+++ b/tools/llm/torchtrt_ext/register_sdpa.py
@@ -1,7 +1,7 @@
 import copy
 import logging
 import operator
-from typing import Callable, Sequence, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
+from transformers import AutoConfig, Gemma3TextConfig
 
 from .sdpa_converter import *
 
@@ -33,94 +34,240 @@
     torch.ops.aten._scaled_dot_product_flash_attention.default,
 }
 
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    get_lowering_pass_config,
+)
+
 
-@_aten_lowering_pass
-def replace_variants_of_sdpa(
-    gm: torch.fx.GraphModule, settings: CompilationSettings
+def _process_sdpa_node(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    settings: CompilationSettings,
+    sliding_window_size: Optional[int] = None,
+    use_gqa: bool = False,
 ) -> torch.fx.GraphModule:
-    """Replace scaled_dot_product_attention with an equivalent
-    implementation which can be accurately converted to TRT
     """
+    Helper function to process SDPA nodes with common logic.
+
+    This function handles the replacement of the various scaled dot product attention operations
+    with the standard torch.nn.functional.scaled_dot_product_attention function. It supports
+    both the efficient attention and flash attention variants, and can handle sliding window
+    attention for models like Gemma3.
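+
+    For a sliding-window layer (as in Gemma3), the rewritten node receives a
+    sliding_window_size kwarg so the converter can band the causal mask accordingly.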
+
+    Args:
+        gm: The graph module containing the SDPA nodes
+        node: The specific node to process (must be an SDPA operation)
+        settings: TensorRT compilation settings
+        sliding_window_size: Optional sliding window size for models with sliding attention
+        use_gqa: Whether the model uses Grouped Query Attention
+
+    Returns:
+        The modified graph module with the SDPA node replaced
+
+    Raises:
+        ValueError: If the SDPA node has an unexpected number of arguments
+    """
+
+    if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+        if len(node.args) == 7:
+            (
+                query,
+                key,
+                value,
+                attn_mask,
+                compute_log_sumexp,
+                dropout_p,
+                is_causal,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, attn_mask, is_causal = node.args
+            dropout_p = 0.0
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+        if len(node.args) == 6:
+            (
+                query,
+                key,
+                value,
+                dropout_p,
+                is_causal,
+                return_debug_mask,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, dropout_p, is_causal = node.args
+        elif len(node.args) == 3:
+            query, key, value = node.args
+            dropout_p = 0.0
+            is_causal = True
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    else:
+        return gm
+
+    # Always set causal to True and generate the attn_mask inside the SDPA operator
     attn_mask = None
     is_causal = True
-    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
-            if (
-                node.target
-                == torch.ops.aten._scaled_dot_product_efficient_attention.default
-            ):
-                if len(node.args) == 7:
-                    (
-                        query,
-                        key,
-                        value,
-                        attn_bias,
-                        compute_log_sumexp,
-                        dropout_p,
-                        is_causal,
-                    ) = node.args
-                elif len(node.args) == 5:
-                    query, key, value, attn_mask, is_causal = node.args
-                    dropout_p = 0.0
-
-                else:
-                    raise ValueError(
-                        f"Unexpected number of arguments for {node.target} in the graph"
-                    )
-            elif (
-                node.target
-                == torch.ops.aten._scaled_dot_product_flash_attention.default
-            ):
-                if len(node.args) == 6:
-                    query, key, value, dropout_p, is_causal, return_debug_mask = (
-                        node.args
-                    )
-                if len(node.args) == 5:
-                    query, key, value, dropout_p, is_causal = node.args
-                elif len(node.args) == 3:
-                    query, key, value = node.args
-                    dropout_p = 0.0
-                    is_causal = True
-                else:
-                    raise ValueError(
-                        f"Unexpected number of arguments for {node.target} in the graph"
-                    )
+    dropout_p = 0.0
+
+    logger.debug(
+        f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, "
+        f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}"
+    )
+
+    modified_input_args = (
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+    )
+
+    # Create a new node with torch.nn.functional.scaled_dot_product_attention
+    with gm.graph.inserting_after(node):
+        new_node = gm.graph.call_function(
+            torch.nn.functional.scaled_dot_product_attention,
+            args=modified_input_args,
+            kwargs={
+                "scale": node.kwargs.get("scale", None),
+                "use_fp32_acc": settings.use_fp32_acc,
+                "sliding_window_size": sliding_window_size,
+            },
+        )
+
+    # Deep copy encounters "RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor)", so we use a shallow copy instead.
+    new_node.meta = copy.copy(node.meta)
+    # Check if there's a getitem node following this attention node
+    for user in list(node.users):
+        if user.op == "call_function" and user.target == operator.getitem:
+            # If the getitem is extracting the first element (the output tensor)
+            if user.args[1] == 0:
+                # Replace all uses of the getitem with the new attention node
+                user.replace_all_uses_with(new_node)
+                new_node.meta["val"] = new_node.meta["val"][0]
+    # Replace all uses of the original node with the new node
+    node.replace_all_uses_with(new_node)
+
+    gm.graph.erase_node(node)
+    return gm
+
+
+def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    """
+    Register the SDPA pass for Gemma3 models with sliding window attention.
+
+    This function creates and registers a specialized SDPA replacement pass for Gemma3 models.
+    The pass handles sliding window attention by extracting the sliding_window and layer_types
+    configuration from the model config and applying the appropriate transformations.
+
+    Args:
+        index: Position in the lowering pass list where this pass should be inserted
+        model_config: The model configuration object (should be a Gemma3TextConfig)
+
+    Example:
+        from transformers import AutoConfig
+        config = AutoConfig.from_pretrained("google/gemma-3-1b-it")
+        register_gemma3_sdpa_pass(index=0, model_config=config)
+
+    Note:
+        This pass is specifically designed for Gemma3 models and will fall back to
+        the default behavior if the model_config is not a Gemma3TextConfig.
+    """
+
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def gemma3_sdpa_pass(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        """SDPA pass specifically for Gemma3 models with sliding window attention."""
+        config = get_lowering_pass_config(gemma3_sdpa_pass)
+        sliding_window = None
+        layer_types = None
+        model_config = config.get("model_config", None)
+        if not isinstance(model_config, Gemma3TextConfig):
             logger.warning(
-                f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
+                f"Expected Gemma3TextConfig, got {type(model_config)}, will use the default SDPA replacement instead"
             )
+        else:
+            sliding_window = getattr(model_config, "sliding_window", None)
+            layer_types = getattr(model_config, "layer_types", None)
+            logger.debug(
+                f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}"
             )
-            modified_input_args = (query, key, value, None, dropout_p, True)
-            # Create a new node with torch.nn.functional.scaled_dot_product_attention
-            # The input args is (query, key, value, is_causal). kwargs has scale
-            with gm.graph.inserting_after(node):
-                new_node = gm.graph.call_function(
-                    torch.nn.functional.scaled_dot_product_attention,
-                    args=modified_input_args,
-                    kwargs={
-                        "scale": node.kwargs.get("scale", None),
-                        "use_fp32_acc": settings.use_fp32_acc,
-                    },
+
+        index = 0
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                sliding_window_size = None
+                if (
+                    sliding_window is not None
+                    and sliding_window > 0
+                    and layer_types is not None
+                    and index < len(layer_types)
+                ):
+                    if layer_types[index] == "sliding_attention":
+                        sliding_window_size = sliding_window
+                    index += 1
+
+                # Process the node
+                logger.debug(
+                    f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}"
                 )
+                gm = _process_sdpa_node(gm, node, settings, sliding_window_size)
 
-            # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
-            new_node.meta = copy.copy(node.meta)
-            # Check if there's a getitem node following this attention node
-            for user in list(node.users):
-                if user.op == "call_function" and user.target == operator.getitem:
-                    # If the getitem is extracting the first element (the output tensor)
-                    if user.args[1] == 0:
-                        # Replace all uses of the getitem with the new attention node
-                        user.replace_all_uses_with(new_node)
-                        new_node.meta["val"] = new_node.meta["val"][0]
-            # Replace all uses of the original node with the new node
-            node.replace_all_uses_with(new_node)
+        clean_up_graph_after_modifications(gm)
+        logger.debug("Applied Gemma3-specific SDPA replacement")
+        return gm
 
-            gm.graph.erase_node(node)
 
-    # Clean up the graph
-    clean_up_graph_after_modifications(gm)
+def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    """
+    Register the default SDPA pass for models without specific implementations.
 
-    logger.debug(
-        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
-    )
-    return gm
+    This function creates and registers a default SDPA replacement pass that can be used
+    for any model type. It provides basic SDPA replacement functionality without
+    model-specific optimizations.
+
+    Args:
+        index: Position in the lowering pass list where this pass should be inserted
+        model_config: The model configuration object (optional, for consistency)
+
+    Example:
+        # Register the default pass at index 0
+        register_default_sdpa_pass(index=0)
+
+        # Or with a model config, for consistency
+        config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
+        register_default_sdpa_pass(index=0, model_config=config)
+
+    Note:
+        This is a fallback pass that should be used when no model-specific
+        SDPA pass is available or when you want generic SDPA replacement behavior.
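+
+    Dispatch sketch (mirrors the lookup in run_llm.py's get_model):
+        register_fn = _SDPA_MAPPING.get(model_name, _SDPA_MAPPING["default"])
+        register_fn(model_config=config)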
+    """
+
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def default_sdpa_pass(
+        gm: torch.fx.GraphModule,
+        settings: CompilationSettings,
+    ) -> torch.fx.GraphModule:
+        """Default SDPA pass for models without specific implementations."""
+
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                # Process the node with the default logic
+                gm = _process_sdpa_node(gm, node, settings)
+
+        clean_up_graph_after_modifications(gm)
+        logger.debug("Applied default SDPA replacement")
+        return gm
+
+
+# Global registry for SDPA passes
+_SDPA_MAPPING: Dict[str, Callable] = {
+    "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
+    "default": register_default_sdpa_pass,
+}
diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py
index d05b0379a4..feded31023 100644
--- a/tools/llm/torchtrt_ext/sdpa_converter.py
+++ b/tools/llm/torchtrt_ext/sdpa_converter.py
@@ -27,24 +27,110 @@ def tril(
     name: str,
     row: TRTTensor,
     col: TRTTensor,
+    sliding_window_size: Optional[int] = None,
 ) -> TRTTensor:
+    """
+    Create a lower triangular mask tensor for attention mechanisms.
+
+    This function generates a lower triangular mask that can be used in attention
+    operations to enforce causal attention (each position can only attend to itself
+    and previous positions). It optionally supports sliding window attention by
+    limiting the attention span to a specified window size.
+
+    The function creates the mask by:
+    1. Generating row and column index tensors
+    2. Computing the difference between row and column indices
+    3. Creating a mask where row >= col (lower triangular)
+    4. Optionally applying sliding window constraints
+
+    Args:
+        ctx: TensorRT conversion context for managing the conversion process
+        target: Target operation identifier (usually the operation being converted)
+        source_ir: Source IR type (e.g., ATEN, TRT) - can be None
+        name: Base name for generated TensorRT operations (will be extended with suffixes)
+        row: Tensor representing the number of rows (sequence length dimension)
+        col: Tensor representing the number of columns (sequence length dimension)
+        sliding_window_size: Optional sliding window size for attention span limitation.
+            If None, creates a full lower triangular mask.
+            If specified, creates a sliding window mask where each position
+            can only attend to positions within the window.
+
+    Returns:
+        TRTTensor: A boolean mask tensor of shape [seq_len, seq_len] (broadcast across the
+        batch and head dimensions) where True values indicate allowed attention positions.
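+
+        Equivalently: allowed(i, j) = (i >= j) and (i - j < sliding_window_size) when a
+        sliding window is set, and allowed(i, j) = (i >= j) otherwise.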
+
+    Example:
+        # Create a full lower triangular mask for causal attention
+        mask = tril(ctx, target, source_ir, "causal_mask", seq_len, seq_len)
+
+        # Create a sliding window mask with window size 3
+        mask = tril(ctx, target, source_ir, "sliding_mask", seq_len, seq_len, 3)
+
+    Mask Examples:
+        Without sliding window (sliding_window_size=None):
+        For seq_len=5, returns:
+            [[ True, False, False, False, False],
+             [ True,  True, False, False, False],
+             [ True,  True,  True, False, False],
+             [ True,  True,  True,  True, False],
+             [ True,  True,  True,  True,  True]]
+
+        With sliding window (sliding_window_size=3):
+        For seq_len=5, returns:
+            [[ True, False, False, False, False],
+             [ True,  True, False, False, False],
+             [ True,  True,  True, False, False],
+             [False,  True,  True,  True, False],
+             [False, False,  True,  True,  True]]
+
+    Note:
+        This function is specifically designed for attention mechanisms in transformer
+        models and is used internally by the scaled_dot_product_attention converter.
+        The sliding window functionality is particularly useful for models like Gemma3
+        that use sliding window attention to reduce computational complexity.
+    """
     row_arange_tensor = impl.arange.arange(
         ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
     )
-    row_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
-    )
-
     col_arange_tensor = impl.arange.arange(
         ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
     )
-    col_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
+    row_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1
     )
+    col_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0
+    )
+    # sub will return the following mask tensor:
+    # [[0, -1, -2, -3],
+    #  [1,  0, -1, -2],
+    #  [2,  1,  0, -1],
+    #  [3,  2,  1,  0]]
+    mask = impl.elementwise.sub(
+        ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor
+    )
+    ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0)
+    if sliding_window_size is None:
+        # return the following lower triangular mask, which includes the main diagonal:
+        # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+        # 1 ■ ■ ⬚ ⬚ ⬚              [ True,  True, False, False, False],
+        # 2 ■ ■ ■ ⬚ ⬚              [ True,  True,  True, False, False],
+        # 3 ■ ■ ■ ■ ⬚              [ True,  True,  True,  True, False],
+        # 4 ■ ■ ■ ■ ■              [ True,  True,  True,  True,  True]]]])
+        return ge_0_mask
 
-    mask = impl.elementwise.ge(
-        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
+    lt_window_mask = impl.elementwise.lt(
+        ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size
+    )
+    mask = impl.elementwise.logical_and(
+        ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask
     )
+    # return the following mask if sliding_window_size is 3:
+    # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+    # 1 ■ ■ ⬚ ⬚ ⬚              [ True,  True, False, False, False],
+    # 2 ■ ■ ■ ⬚ ⬚              [ True,  True,  True, False, False],
+    # 3 ⬚ ■ ■ ■ ⬚              [False,  True,  True,  True, False],
+    # 4 ⬚ ⬚ ■ ■ ■              [False, False,  True,  True,  True]]]])
     return mask
@@ -60,15 +146,19 @@ def scaled_dot_product_attention(
     kwargs: Dict[str, Any],
     name: str,
 ) -> TRTTensor:
-    # TODO: Handle attn_mask and is_causal arguments in the future
-    query, key, value, attn_mask, dropout_p, is_causal = args
+    # always create our own attn_mask
+    query, key, value, _, dropout_p, is_causal = args
 
     # TODO: remove this once we have a better way to handle the causal mask
     scale = kwargs.get("scale", None)
     source_ir = SourceIR.ATEN
-    is_causal = True
+
+    assert is_causal == True, "is_causal should be set to True"
+
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     use_fp32_acc = kwargs.get("use_fp32_acc", False)
+    sliding_window_size = kwargs.get("sliding_window_size", None)
+
     query_dtype = query.dtype
 
     if scale is None:
@@ -136,7 +226,9 @@
     S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
 
     # generate the mask tensor
-    tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+    tril_tensor = tril(
+        ctx, target, source_ir, name + "_tril", L, S, sliding_window_size
+    )
 
     temp_mask = impl.unary.logical_not(
         ctx, target, source_ir, name + "_logical_not", tril_tensor
@@ -165,11 +257,9 @@
     attn_bias = impl.unary.log(
         ctx, target, source_ir, name + "_log", one_minus_temp_mask
     )
-
-    scaled_add_attn_bias = impl.elementwise.add(
-        ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
-    )
-
+    scaled_add_attn_bias = impl.elementwise.add(
+        ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+    )
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
     )
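
For reference, a minimal end-to-end sketch of the new flow, mirroring the added test in `tests/py/dynamo/models/test_llm_models.py` (assumes `tools/llm` is on `sys.path`; the fixed `input_ids` below are the ones used by the test):

```python
import argparse

import torch
from transformers import AutoModelForCausalLM

from run_llm import compile_torchtrt  # tools/llm/run_llm.py
from torchtrt_ext import register_sdpa  # tools/llm/torchtrt_ext/register_sdpa.py

model_name = "google/gemma-3-1b-it"
model = (
    AutoModelForCausalLM.from_pretrained(
        model_name, use_cache=False, attn_implementation="sdpa"
    )
    .eval()
    .to("cuda")
    .to(torch.float16)
)

# Register the model-specific SDPA lowering pass, falling back to the default one
register_fn = register_sdpa._SDPA_MAPPING.get(
    model_name, register_sdpa._SDPA_MAPPING["default"]
)
register_fn(model_config=model.config)

# compile_torchtrt reads its settings from an argparse-style namespace
args = argparse.Namespace(
    debug=False,
    num_tokens=128,
    model=model_name,
    precision="FP16",
    min_block_size=1,
    prompt="What is parallel programming ?",
)
input_ids = torch.tensor([[2, 3689, 563, 10616, 14929, 2360]]).to("cuda")
trt_model = compile_torchtrt(model, input_ids, args)
```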