def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` is a boolean mask.

    Args:
        tensor: An index operand; a TensorRT tensor, a NumPy array or a
            torch tensor.

    Returns:
        True if the operand's dtype is boolean, False otherwise (including
        for operands of any other type).
    """
    if isinstance(tensor, TRTTensor):
        meta = getattr(tensor, "meta", None)
        if meta is None:
            # No metadata attached: the TensorRT dtype is the only information.
            return bool(tensor.dtype == trt.DataType.BOOL)
        val = meta.get("val")
        return val is not None and val.dtype is torch.bool
    if isinstance(tensor, torch.Tensor):
        return tensor.dtype == torch.bool
    if isinstance(tensor, np.ndarray):
        # A NumPy dtype must be compared with NumPy's boolean dtype; comparing
        # it with ``torch.bool`` never matches, which hid NumPy boolean masks.
        return bool(tensor.dtype == np.bool_)
    return False


def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace every boolean-mask index with integer indices from a NonZero layer.

    Non-boolean entries and ``None`` placeholders are passed through unchanged
    and keep their position.

    Args:
        ctx: Conversion context whose ``net`` receives the new layers.
        target: Target used when naming the created layers.
        source_ir: Source IR used when naming the created layers.
        name: Base name; per-index suffixes are appended to it.
        input: The tensor being indexed (not read by this function).
        indices: The index operands, one per indexed dimension.

    Returns:
        A new list with the same length as ``indices`` in which each boolean
        mask has been replaced by the tensor produced for it.
    """
    expanded = []
    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            expanded.append(ind)
            continue

        _LOGGER.debug(
            f"Boolean index detected at position {i}, converting with nonzero()"
        )

        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(
            nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
        )
        nonzero_indices = nonzero_layer.get_output(0)

        if len(indices) == 1:
            # x[mask]: flatten the nonzero output into a 1D index tensor.
            squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_nonzero_squeeze_{i}",
                source_ir,
            )
            converted = squeeze_layer.get_output(0)
        else:
            # Several indices: pick entry ``i`` along axis 1 of the nonzero output.
            # NOTE(review): the original comment describes the nonzero output as
            # [N, dims]; TensorRT's INonZeroLayer appears to document [D, C].
            # Confirm that axis 1 / index ``i`` is the intended selection.
            gather_axis = 1
            gather_layer = ctx.net.add_gather(
                nonzero_indices,
                get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
                gather_axis,
            )
            set_layer_name(
                gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
            )
            converted = gather_layer.get_output(0)

        # Store the converted index; previously the result was only bound to the
        # loop variable and the untouched boolean ``indices`` were returned.
        expanded.append(converted)
    return expanded
class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
    """Checks aten.index.Tensor with a dynamically shaped input and runtime indices."""

    def test_index_input_non_dynamic_index_dynamic(self):
        """Index ``x`` with the positions of its positive elements."""

        class RuntimeIndexModule(torch.nn.Module):
            def forward(self, x):
                # The index tuple is computed from the input at runtime.
                positions = torch.nonzero(x > 0, as_tuple=True)
                return torch.ops.aten.index.Tensor(x, positions)

        dynamic_input = Input(
            min_shape=(2, 2),
            opt_shape=(2, 2),
            max_shape=(8, 8),
            dtype=torch.float32,
        )
        # Per the original note: with use_dynamo_tracer=True the index argument
        # (args[1]) is itself converted to a list of TRTTensors.
        self.run_test_with_dynamic_shape(
            RuntimeIndexModule(),
            [dynamic_input],
            use_dynamo_tracer=True,
        )
a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..bf63801276 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -86,10 +86,7 @@ def replace_variants_of_sdpa( f"Unexpected number of arguments for {node.target} in the graph" ) - logger.warning( - f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." - ) - modified_input_args = (query, key, value, None, dropout_p, True) + modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 47083c7b48..03a960edc4 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,24 +27,51 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = None, ) -> TRTTensor: + row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - col_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + col_arange_tensor = 
impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor + ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask @@ -66,7 +93,7 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype @@ -134,37 +161,43 @@ def scaled_dot_product_attention( L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) if S < 0: S = impl.shape.shape(ctx, target, source_ir, 
name + "_shape_1", key, 2) - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. - # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) - - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + if is_causal: + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. 
+ # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, + temp_mask, + query_dtype, + name + "_casted_bool", + target, + source_ir, + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + else: + attn_bias = attn_mask scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..c56aa9b490 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = (