diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4dcb525405..164f0c1065 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -888,6 +888,7 @@ def aten_ops_select( @dynamo_tensorrt_converter( torch.ops.aten.index_put.default, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { @@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d( @dynamo_tensorrt_converter( - torch.ops.aten.topk.default, capability_validator=topk_validator + torch.ops.aten.topk.default, + capability_validator=topk_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6f4a812dd8..ff743edf27 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -257,15 +257,17 @@ def index( ) else: dim_tensor_shape_mult_d1 = transpose_tensor_shape[i] - mult_d1 = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_shape_{i}", - trt.ElementWiseOperation.PROD, - mult_d1, - dim_tensor_shape_mult_d1, - ) + + if isinstance(dim_tensor_shape_mult_d1, TRTTensor): + mult_d1 = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_shape_{i}", + trt.ElementWiseOperation.PROD, + mult_d1, + dim_tensor_shape_mult_d1, + ) concat_tensor_layer = ctx.net.add_concatenation( [ @@ -548,6 +550,9 @@ def index_put_converter( accumulate: bool = False, ) -> TRTTensor: # Convert 'input_indices' to TRT tensors (or keep None as is) + input_indices = expand_boolean_indices( + ctx, target, source_ir, name, input_tensor, input_indices + ) indices: List[Optional[Union[TRTTensor, None]]] = [] for i, idx in enumerate(input_indices): if idx is None: @@ -571,13 +576,31 @@ def index_put_converter( K = len(I) # Determine the maximum size 'N' among the index tensors if K > 0: - index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None] + index_shapes = ( + [] + ) # [tensor.shape[0] for tensor in indices if tensor is not None] + for idx_tensor in indices: + if idx_tensor is not None: + if idx_tensor.shape[0] != DYNAMIC_DIM: + index_shapes.append(idx_tensor.shape[0]) + else: + index_shapes.append( + get_shape( + ctx, + target, + source_ir, + name + "idx_shape_dim_0", + idx_tensor, + 0, + ) + ) N = max(index_shapes) if index_shapes else 1 else: N = 1 # Compute shapes and volume for the free dimensions F_shapes = [input_tensor.shape[i] for i in F] + assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported" F_volume = trt.volume(F_shapes) if F_shapes else 1 # Process indexed dimensions (I) @@ -585,8 +608,8 @@ def index_put_converter( for i in I: idx = indices[i] assert idx is not None - idx_reshaped = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1) + idx_reshaped = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1 ) expanded_idx = impl.slice.expand( ctx, @@ -608,46 +631,50 @@ def index_put_converter( ) arange_tensors.append(arange_tensor) - meshgrid_tensors = [] - for i, arange in enumerate(arange_tensors): - reshape_shape = [1] * len(F) - reshape_shape[i] = F_shapes[i] - arange_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_arange_F_{F[i]}", - arange, - tuple(reshape_shape), - ) - expanded_arange = 
impl.slice.expand( - ctx, - target, - source_ir, - f"{name}_expand_arange_F_{F[i]}", - arange_reshaped, - tuple(F_shapes), - ) - meshgrid_tensors.append(expanded_arange) - - meshgrid_stacked = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_stack_meshgrid", - [ - impl.shuffle.reshape( + if len(arange_tensors) == 1: + # No need to stack + meshgrid_stacked = arange_tensors[0] + else: + meshgrid_tensors = [] + for i, arange in enumerate(arange_tensors): + reshape_shape = [1] * len(F) + reshape_shape[i] = F_shapes[i] + arange_reshaped = impl.shuffle.reshape( ctx, target, source_ir, - f"{name}_reshape_mesh_{i}", - t, - (*F_shapes, 1), + f"{name}_reshape_arange_F_{F[i]}", + arange, + tuple(reshape_shape), ) - for i, t in enumerate(meshgrid_tensors) - ], - dim=-1, - ) + expanded_arange = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_arange_F_{F[i]}", + arange_reshaped, + tuple(F_shapes), + ) + meshgrid_tensors.append(expanded_arange) + + meshgrid_stacked = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_stack_meshgrid", + [ + impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_mesh_{i}", + t, + (*F_shapes, 1), + ) + for i, t in enumerate(meshgrid_tensors) + ], + dim=-1, + ) meshgrid_reshaped = impl.shuffle.reshape( ctx, target, @@ -672,21 +699,15 @@ def index_put_converter( # Combine all indexed dimensions (I) if K > 0: - I_combined = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_cat_I", - [ - impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) - ) - for i, t in enumerate(I_tensors) - ], - dim=2, - ) + + I_combined = [ + impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) + ) + for i, t in enumerate(I_tensors) + ] else: - I_combined = None + I_combined = [] # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded ii_list = [] @@ -695,24 +716,12 @@ def index_put_converter( for dim in range(rank): unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}" if dim in I: - start = [0, 0, i_idx] - shape = [N, F_volume, 1] - stride = [1, 1, 1] - idx_tensor = impl.slice.slice( - ctx, - target, - source_ir, - f"{name}_slice_I_dim_{unique_suffix}", - I_combined, - start, - shape, - stride, - ) + idx_tensor = I_combined[i_idx] ii_list.append(idx_tensor) i_idx += 1 else: start = [0, 0, f_idx] - shape = [N, F_volume, 1] + shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1] stride = [1, 1, 1] mesh_tensor = impl.slice.slice( ctx, @@ -731,20 +740,24 @@ def index_put_converter( indices_cat = impl.cat.cat( ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2 ) + + # Flatten the indices_cat to (N * F_volume, rank) indices_cat = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_indices_cat", indices_cat, - (N * F_volume, rank), + (-1, rank), ) if not isinstance(values, TRTTensor): values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0) # Define the expected shape based on (N,) + F_shapes - expected_shape = (N,) + tuple(F_shapes) + expected_shape = ( + (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes) + ) # Broadcast 'values' to match the expected shape if len(values.shape) == 0 or values.shape == (1,): # Scalar case @@ -761,7 +774,12 @@ def index_put_converter( ) else: # Non-scalar case values_shape = list(values.shape) - if K > 0 and N in values_shape: + if ( + K > 0 + and N in values_shape + and (len(F) > 1 and max(F) - min(F) + 1 == len(F)) + ): + # 
Continuous case n_idx = values_shape.index(N) permute_order = [n_idx] + [ i for i in range(len(values_shape)) if i != n_idx @@ -807,31 +825,27 @@ def index_put_converter( tuple(broadcast_shape), ) else: + # Discontinuous case values_shape_padded = [1] * ( len(expected_shape) - len(values.shape) ) + list(values.shape) broadcast_shape = [] for exp_dim, val_dim in zip(expected_shape, values_shape_padded): - if val_dim == 1 or exp_dim == val_dim: + if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM: + broadcast_shape.append(-1) + elif val_dim == 1 or exp_dim == val_dim: broadcast_shape.append(exp_dim) else: raise ValueError( f"Cannot broadcast {values.shape} to {expected_shape}" ) - values_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_values", - values, - tuple(broadcast_shape), - ) + values_expanded = impl.slice.expand( ctx, target, source_ir, f"{name}_expand_values", - values_reshaped, + values, expected_shape, ) @@ -842,16 +856,51 @@ def index_put_converter( source_ir, f"{name}_flatten_values", values_expanded, - (N * F_volume,), + (-1,), ) - indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32") - # Perform Scatter ND operation - scatter_layer = ctx.net.add_scatter( - input_tensor, - indices_cat, - flattened_values, - trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD, - ) - set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) - return scatter_layer.get_output(0) + if accumulate: + zero_tensor = impl.full.full( + ctx, + target, + source_ir, + f"{name}_zero_tensor", + [ + get_shape( + ctx, + target, + source_ir, + name + f"input_tensor_shape_dim_{i}", + input_tensor, + i, + ) + for i in range(len(input_tensor.shape)) + ], + 0.0, + dtype=input_tensor.dtype, + ) + # Perform Scatter ND operation + scatter_layer = ctx.net.add_scatter( + zero_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + + scatter_out = scatter_layer.get_output(0) + result = impl.elementwise.add( + ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor + ) + return result + + else: + scatter_layer = ctx.net.add_scatter( + input_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + scatter_out = scatter_layer.get_output(0) + return scatter_out diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py index 2a2c8e9d5e..a9b7c48ec2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -23,7 +23,8 @@ def remove_num_users_is_0_nodes( and len(node.all_input_nodes) > 0 ): gm.graph.erase_node(node) - gm = clean_up_graph_after_modifications(gm) + + gm = clean_up_graph_after_modifications(gm) logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}") diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py index 74e38cd0c5..0f4da97d89 100644 --- a/tests/py/dynamo/conversion/test_index_put_aten.py +++ b/tests/py/dynamo/conversion/test_index_put_aten.py @@ -1,4 +1,5 @@ import torch +import torch_tensorrt as torchtrt from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests @@ -194,11 +195,43 @@ class 
TestIndexPutConverter(DispatchTestCase): dtype=torch.int32, ), ), + # param( + # test_name="4d_indices_none_none_multiple_idx_broadcast_error", + # source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32), + # indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)), + # value_tensor=torch.randn([2, 3, 3], dtype=torch.float32), + # ), + param( + test_name="discontinuous_test", + source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 0, 1], dtype=torch.int64), + None, + torch.tensor([0, 0, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), + ), + param( + test_name="discontinuous_test_two", + source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32), + indices_tensor=( + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), + ), param( - test_name="4d_indices_none_none_multiple_idx_broadcast_error", - source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32), - indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)), - value_tensor=torch.randn([2, 3, 3], dtype=torch.float32), + test_name="continuous_test", + source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32), + indices_tensor=( + None, + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), ), # param( # test_name="2d_indices_accumulate_True", @@ -244,6 +277,94 @@ def forward(self, source_tensor, value_tensor): use_dynamo_tracer=True, ) + def test_index_add_dynamic_shape(self): + + class Model(torch.nn.Module): + def forward(self, x, y, z, a, b): + x.index_add_(0, y, z) + x.index_add_(0, a, b) + return x + + dim = 10 + model = Model().cuda() + inputs = [ + torch.ones((12, dim)).half().cuda(), + torch.tensor([0, 1]).cuda(), + torch.randn((2, dim)).half().cuda(), + torch.tensor([2, 9, 11]).cuda(), + torch.randn((3, dim)).half().cuda(), + ] + torch_output = model.cuda().forward(*inputs) + seq_len1 = torch.export.Dim("seq_len1", min=1, max=128) + seq_len2 = torch.export.Dim("seq_len2", min=1, max=128) + seq_len3 = torch.export.Dim("seq_len3", min=1, max=128) + + ep = torch.export.export( + model, + tuple(inputs), + dynamic_shapes=( + {0: seq_len1}, + {0: seq_len2}, + {0: seq_len2}, + {0: seq_len3}, + {0: seq_len3}, + ), + ) + with torchtrt.dynamo.Debugger( + log_level="debug", + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, + ): + trt_mod = torchtrt.dynamo.compile( + ep, + inputs, + enabled_precisions={torch.float16}, + min_block_size=1, + use_explicit_typing=False, + use_fp32_acc=False, + disable_tf32=True, + ) + result = trt_mod(*inputs) + assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4) + + def test_bool_mask_test(self): + + source_tensor = torch.ones([5, 10], dtype=torch.float32).cuda() + indices_tensor = torch.tensor([False, False, True, False, True]) + value_tensor = torch.zeros([2, 10], dtype=torch.float32).cuda() + + dim1 = torch.export.Dim("dim1", min=1, max=5) + dim2 = torch.export.Dim("dim2", min=1, max=5) + + class TestIndexPut(torch.nn.Module): + def forward(self, source_tensor, indices_tensor, value_tensor): + source_tensor[indices_tensor] = value_tensor + return source_tensor + + model = TestIndexPut() + torch_output = model.forward(source_tensor, 
indices_tensor, value_tensor)
+
+        ep = torch.export.export(
+            model,
+            (source_tensor, indices_tensor, value_tensor),
+            dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}),
+        )
+        with torchtrt.dynamo.Debugger(log_level="debug"):
+            trt_engine = torchtrt.dynamo.compile(
+                ep,
+                inputs=(source_tensor, indices_tensor, value_tensor),
+                enabled_precisions={torch.float32},
+                min_block_size=1,
+                use_explicit_typing=False,
+                use_fp32_acc=False,
+                disable_tf32=True,
+                use_python_runtime=True,
+            )
+            result = trt_engine(source_tensor, indices_tensor, value_tensor)
+
+            assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
index ab9470cc61..97b6616581 100644
--- a/tools/llm/run_llm.py
+++ b/tools/llm/run_llm.py
@@ -71,7 +71,7 @@ def get_model(args):
     else:
         model = model.to(torch.float32)
 
-    return model
+    return model.cuda()
 
 
 def compile_torchtrt(model, input_ids, args):
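
For context on the new `accumulate` branch in `index_put_converter`: rather than scattering directly into the input with `ND_ELEMENTWISE_ADD`, the converter now scatters the flattened values into a zero tensor shaped like the input and then adds that result to the original input. Below is a minimal eager-PyTorch sketch of the same decomposition (illustration only, not part of the PR; it assumes the gathered indices are unique, since a plain scatter resolves duplicate indices by overwrite rather than by summation):

```python
import torch

# Reference behaviour: accumulate the update values directly into the input.
x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
idx = (torch.tensor([0, 2]),)   # unique indices along dim 0
vals = torch.ones(2, 3)

ref = x.clone()
ref.index_put_(idx, vals, accumulate=True)   # x[idx] += vals

# Decomposed form mirroring the converter: scatter into a zero tensor, then add the input.
scattered = torch.zeros_like(x)
scattered.index_put_(idx, vals)              # plain scatter (accumulate=False)
out = x + scattered

assert torch.equal(ref, out)
```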