Skip to content

Commit 1376536

Browse files
committed
Updated README.md
Signed-off-by: Bo Wang <[email protected]>
1 parent 8b6d80c commit 1376536

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

examples/README.md

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,22 +106,31 @@ converter and use it in our application:
106106
import torch
107107
import trtorch
108108

109+
# After "python3 setup install", you should find this .so file under generated "build" directory
109110
torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
110111

112+
111113
class Elu(torch.nn.Module):
114+
112115
def __init__(self):
113116
super(Elu, self).__init__()
114117
self.elu = torch.nn.ELU()
115118

116119
def forward(self, x):
117120
return self.elu(x)
118121

122+
123+
def MaxDiff(pytorch_out, trtorch_out):
124+
diff = torch.sub(pytorch_out, trtorch_out)
125+
abs_diff = torch.abs(diff)
126+
max_diff = torch.max(abs_diff)
127+
print("Maximum differnce between TRTorch and PyTorch: \n", max_diff)
128+
129+
119130
def main():
120-
data = torch.randn((1, 1, 2, 2)).to("cuda")
121131
model = Elu().eval() #.cuda()
122132

123133
scripted_model = torch.jit.script(model)
124-
print(scripted_model.graph)
125134
compile_settings = {
126135
"input_shapes": [{
127136
"min": [1024, 1, 32, 32],
@@ -133,10 +142,14 @@ def main():
133142
}
134143
trt_ts_module = trtorch.compile(scripted_model, compile_settings)
135144
input_data = torch.randn((1024, 1, 32, 32))
136-
print(input_data[0, :, :, 0])
137145
input_data = input_data.half().to("cuda")
138-
result = trt_ts_module(input_data)
139-
print(result[0, :, :, 0])
146+
pytorch_out = model.forward(input_data)
147+
148+
trtorch_out = trt_ts_module(input_data)
149+
print('PyTorch output: \n', pytorch_out[0, :, :, 0])
150+
print('TRTorch output: \n', trtorch_out[0, :, :, 0])
151+
MaxDiff(pytorch_out, trtorch_out)
152+
140153

141154
if __name__ == "__main__":
142155
main()

0 commit comments

Comments
 (0)