Skip to content

Commit 532b386

Browse files
committed
test(): added interpolate model for engine serialization testing
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 6308190 commit 532b386

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

tests/modules/hub.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
24
import torchvision.models as models
35

46
models = {
@@ -65,6 +67,7 @@
6567
}
6668
}
6769

70+
# Download sample models
6871
for n, m in models.items():
6972
print("Downloading {}".format(n))
7073
m["model"] = m["model"].eval().cuda()
@@ -74,4 +77,20 @@
7477
torch.jit.save(trace_model, n + '_traced.jit.pt')
7578
if m["path"] == "both" or m["path"] == "script":
7679
script_model = torch.jit.script(m["model"])
77-
torch.jit.save(script_model, n + '_scripted.jit.pt')
80+
torch.jit.save(script_model, n + '_scripted.jit.pt')
81+
82+
# Sample Interpolation Model (align_corners=False, for Testing Interpolate Plugin specifically)
83+
class Interpolate(nn.Module):
84+
def __init__(self):
85+
super(Interpolate, self).__init__()
86+
87+
def forward(self, x):
88+
return F.interpolate(x, size=(10, 10, 10), align_corners=False, mode="trilinear")
89+
90+
model = Interpolate().eval().cuda()
91+
x = torch.ones([1, 3, 5, 5, 5]).cuda()
92+
93+
trace_model = torch.jit.trace(model, x)
94+
torch.jit.save(trace_model, "interpolate_traced.jit.pt")
95+
96+

0 commit comments

Comments
 (0)