diff --git a/examples/dynamo/llama2_flashinfer_rmsnorm.py b/examples/dynamo/llama2_flashinfer_rmsnorm.py
index 7542a9a1b7..a5e7fc672f 100644
--- a/examples/dynamo/llama2_flashinfer_rmsnorm.py
+++ b/examples/dynamo/llama2_flashinfer_rmsnorm.py
@@ -15,11 +15,12 @@
 This example illustrates advanced extensibility in Torch-TensorRT through
 automatic plugin generation and operator lowering customization.
 """
-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import flashinfer
 import torch
 import torch_tensorrt
+from torch._subclasses import FakeTensor
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
     _aten_lowering_pass,
@@ -51,6 +52,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tenso
 def replace_rmsnorm(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
+    # Dump the FX graph before the rewrite to aid debugging of this pass
+    print(gm.graph)
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._to_copy.default
@@ -90,13 +93,63 @@ def replace_rmsnorm(
                     weight_mul_node = list(copy_node.users)[0]
                     weight = weight_mul_node.args[0]
 
-                    original_meta = weight_mul_node.meta.get(
+                    hidden_states_node = node.args[0]
+                    original_meta = hidden_states_node.meta.get(
                         "tensor_meta", {}
                     )
                     memory_format = original_meta.memory_format
+                    from torch.fx.experimental.symbolic_shapes import (
+                        ShapeEnv,
+                    )
+
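+                    # A fresh ShapeEnv mints unbacked symbolic ints for this pass;
+                    # the paired torch._check bounds below pin each symint to the
+                    # concrete value observed at trace time.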
+                    shape_env = ShapeEnv()
 
                     with gm.graph.inserting_after(weight_mul_node):
+                        input_meta = node.args[0].meta["val"]
+                        batch_size = input_meta.shape[0]
+                        seq_len = input_meta.shape[1]
+                        head_dim = input_meta.shape[2]
+
+                        # Create symbolic ints for batch_size
+                        if isinstance(batch_size, int):
+                            batch_size_unbacked_symint = (
+                                shape_env.create_unbacked_symint()
+                            )
+                            torch._check(
+                                batch_size_unbacked_symint >= batch_size
+                            )
+                            torch._check(
+                                batch_size_unbacked_symint <= batch_size
+                            )
+                        elif isinstance(batch_size, torch.SymInt):
+                            pass
+                        else:
+                            raise ValueError(
+                                "batch_size must be an int or a SymInt"
+                            )
+
+                        # Create symbolic ints for head_dim
+                        if isinstance(head_dim, int):
+                            head_dim_unbacked_symint = (
+                                shape_env.create_unbacked_symint()
+                            )
+                            torch._check(
+                                head_dim_unbacked_symint >= head_dim
+                            )
+                            torch._check(
+                                head_dim_unbacked_symint <= head_dim
+                            )
+                        elif isinstance(head_dim, torch.SymInt):
+                            pass
+                        else:
+                            raise ValueError(
+                                "head_dim must be an int or a SymInt"
+                            )
+
                         b = gm.graph.create_node(
                             op="call_function",
                             target=torch.ops.aten.sym_size.int,
                             args=(node.args[0], 0),
@@ -111,19 +164,24 @@ def replace_rmsnorm(
                             is_quantized=False,
                             qparams={},
                         )
+
+                        # Dim 0 carries the unbacked batch-size symint created above
+                        b.meta["val"] = batch_size_unbacked_symint
+
                         s = gm.graph.create_node(
                             op="call_function",
                             target=torch.ops.aten.sym_size.int,
                             args=(node.args[0], 1),
                         )
                         s.meta.update(b.meta)
-
+                        s.meta["val"] = seq_len
                         d = gm.graph.create_node(
                             op="call_function",
                             target=torch.ops.aten.sym_size.int,
                             args=(node.args[0], 2),
                         )
                         d.meta.update(b.meta)
+                        d.meta["val"] = head_dim_unbacked_symint
 
                         with gm.graph.inserting_after(b):
                             new_first_dim = gm.graph.create_node(
@@ -150,11 +208,11 @@ def replace_rmsnorm(
                                     [b_val * s_val, d_val]
                                 ),
                                 dtype=original_meta.dtype,
-                                requires_grad=True,
+                                requires_grad=False,
                                 stride=None,
                                 memory_format=memory_format,
                                 is_quantized=False,
                                 qparams={},
                             )
                         )
 
@@ -183,11 +241,22 @@ def replace_rmsnorm(
                                     [b, s, d],
                                 ),
                             )
+                            reshapback_node.meta["tensor_meta"] = (
+                                TensorMetadata(
+                                    shape=torch.Size([b_val, s_val, d_val]),
+                                    dtype=original_meta.dtype,
+                                    stride=None,
+                                    memory_format=memory_format,
+                                    is_quantized=False,
+                                    qparams={},
+                                    requires_grad=False,
+                                )
+                            )
+                            # tensor_meta is built explicitly here instead of copying weight_mul_node.meta
 
                             weight_mul_node.replace_all_uses_with(
                                 reshapback_node
                             )
-                            reshapback_node.meta.update(weight_mul_node.meta)
 
                             modified_graph = True
 
@@ -207,6 +276,43 @@ def replace_rmsnorm(
     return gm
 
 
+@_aten_lowering_pass
+def set_copy_node_meta_data(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    # _to_copy nodes introduced during lowering may lack tensor_meta; fill it in
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten._to_copy.default and (
+            "tensor_meta" not in node.meta
+        ):
+            input_node = node.args[0]
+
+            # Check if the input carries shape metadata
+            if "tensor_meta" in input_node.meta:
+                output_meta = input_node.meta["tensor_meta"]
+
+                # Rebuild the metadata from the input, overriding the dtype
+                # with the cast target of the _to_copy node (falling back to
+                # the input dtype when no dtype kwarg is present)
+                node.meta["tensor_meta"] = TensorMetadata(
+                    shape=output_meta.shape,
+                    dtype=node.kwargs.get("dtype", output_meta.dtype),
+                    requires_grad=False,
+                    stride=None,
+                    memory_format=output_meta.memory_format,
+                    is_quantized=False,
+                    qparams={},
+                )
+
+            else:
+                # Missing metadata is non-fatal; warn and move on
+                print(f"Warning: Input node {input_node} has no tensor_meta")
+
+    gm = clean_up_graph_after_modifications(gm)
+
+    return gm
+
+
 # 1. Create a custom config with 1 layer
 config = LlamaConfig(
     vocab_size=32000,
@@ -222,12 +328,16 @@ def replace_rmsnorm(
 with torch.no_grad():
     model = LlamaForCausalLM(config).eval().half()
 
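+# torch.export.Dim marks dim 1 (sequence length) of input_ids as dynamic,
+# so the exported program accepts prompts of 2 to MAX_TOKENS tokens.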
+MAX_TOKENS = 64
+seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)
-# 3. Export with static shapes
+# 3. Export with a dynamic sequence length
 input_ids = torch.randint(0, 32000, (1, 64))  # Static [batch=1, seq=64]
 exported = torch.export.export(
     model,
     (input_ids,),
-    dynamic_shapes=None,  # Fully static
+    dynamic_shapes=({1: seq_len},),
 )
 
 # Test forward pass
@@ -238,20 +348,66 @@ def replace_rmsnorm(
 
 # Export validation
 DEVICE = torch.device("cuda:0")
-
-with torch_tensorrt.logging.errors():
-    trt_model = torch_tensorrt.dynamo.compile(
-        exported,
-        inputs=[input_ids],
-        enabled_precisions={torch.float32, torch.float16},
-        truncate_double=True,
-        device=DEVICE,
-        disable_tf32=True,
-        use_explicit_typing=False,
-        use_fp32_acc=True,
-    )
-
-input_ids = input_ids.to(DEVICE)
-
-res = trt_model.forward(input_ids)
-print(res)
+stream = torch.cuda.Stream()
+with torch.cuda.stream(stream):
+    with torch_tensorrt.dynamo.Debugger(
+        log_level="info",
+        # profile_format="trex",
+        # save_engine_profile=True,
+        capture_fx_graph_before=["remove_detach"],
+        capture_fx_graph_after=["replace_rmsnorm"],
+        logging_dir="/home/profile/logging/torchtrt",
+        engine_builder_monitor=False,
+    ):
+        trt_model = torch_tensorrt.dynamo.compile(
+            exported,
+            inputs=[input_ids],
+            enabled_precisions={torch.float32, torch.float16},
+            truncate_double=True,
+            device=DEVICE,
+            disable_tf32=True,
+            use_explicit_typing=False,
+            use_fp32_acc=True,
+            use_python_runtime=True,
+        )
+
+    input_ids = input_ids.to(DEVICE)
+
+    res = trt_model.forward(input_ids)
+
+    # Benchmark the compiled TensorRT model
+
+    import time
+
+    def benchmark_model(model, input_ids, label, n_runs=100):
+        # Synchronize before and after the loop so queued GPU work is
+        # fully included in the wall-clock measurement
+        torch.cuda.synchronize()
+        start = time.time()
+        for _ in range(n_runs):
+            with torch.no_grad():
+                out = model(input_ids)
+        torch.cuda.synchronize()
+        end = time.time()
+        print(f"{label}: {n_runs} runs, total {(end - start):.4f} s")
+        return out
+
+    # Warmup
+    with torch.no_grad():
+        _ = trt_model(input_ids)
+
+    # Benchmark
+    trt_out = benchmark_model(trt_model, input_ids, "TensorRT model")
+
+# Compare outputs
+
+pytorch_logits = output.logits
+trt_logits = trt_out.logits
+
+pytorch_logits = pytorch_logits.to(DEVICE)
+trt_logits = trt_logits.to(DEVICE)
+print("Max abs diff:", (pytorch_logits - trt_logits).abs().max().item())
+print("Mean abs diff:", (pytorch_logits - trt_logits).abs().mean().item())
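+
+# The printed max/mean absolute differences give a quick numerical sanity
+# check of the TensorRT engine against the PyTorch baseline.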