@@ -34,7 +34,7 @@ Here's an example usage
3434 model = MyModel().eval().cuda()
3535 inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
3636 # trt_ep is a torch.fx.GraphModule object
37- trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs)
37+ trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs = inputs)
3838 torch_tensorrt.save(trt_gm, " trt.ep" , inputs = inputs)
3939
4040 # Later, you can load it and run inference
@@ -52,7 +52,7 @@ b) Torchscript
5252 model = MyModel().eval().cuda()
5353 inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
5454 # trt_gm is a torch.fx.GraphModule object
55- trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs)
55+ trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs = inputs )
5656 torch_tensorrt.save(trt_gm, " trt.ts" , output_format = " torchscript" , inputs = inputs)
5757
5858 # Later, you can load it and run inference
@@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
7373
7474 model = MyModel().eval().cuda()
7575 inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
76- trt_ts = torch_tensorrt.compile(model, ir = " ts" , inputs) # Output is a ScriptModule object
76+ trt_ts = torch_tensorrt.compile(model, ir = " ts" , inputs = inputs ) # Output is a ScriptModule object
7777 torch.jit.save(trt_ts, " trt_model.ts" )
7878
7979 # Later, you can load it and run inference
0 commit comments