Skip to content

Commit 06668a2

Browse files
author
Nikil Ravi
committed
modified file paths
1 parent bd1b88c commit 06668a2

File tree

5 files changed

+81
-33
lines changed

5 files changed

+81
-33
lines changed

10_FAIR_AI/trt_and_containerization/build_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import tensorrt as trt
55

66
name = 'cifar_net_10_epochs'
7-
engine_name = '{}_4.plan'.format(name)
8-
onnx_path = '/home/nravi/ai-science-training-series/{}.onnx'.format(name)
7+
engine_name = '{}_4_trial.plan'.format(name)
8+
onnx_path = '/home/nravi/ai-science-training-series/10_FAIR_AI/trt_and_containerization/saved_models/{}.onnx'.format(name)
99
batch_size = 4
1010

1111
model = ModelProto()

10_FAIR_AI/trt_and_containerization/cifar-model-for-trt.ipynb

Lines changed: 71 additions & 24 deletions
Large diffs are not rendered by default.

10_FAIR_AI/trt_and_containerization/cifar-script-torch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def forward(self, x):
2828
return x
2929

3030

31-
PATH = './cifar_net_10_epochs.pth'
31+
PATH = './saved_models/cifar_net_10_epochs.pth'
3232
net = Net()
3333
net.load_state_dict(torch.load(PATH))
3434

@@ -37,13 +37,13 @@ def forward(self, x):
3737
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
3838

3939
batch_size = 250
40-
41-
trainset = torchvision.datasets.CIFAR10(root='./', train=True,
40+
data_dir = '../../../ai-science-training-series-old/'
41+
trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True,
4242
download=False, transform=transform)
4343
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
4444
shuffle=True, num_workers=2)
4545

46-
testset = torchvision.datasets.CIFAR10(root='./', train=False,
46+
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False,
4747
download=False, transform=transform)
4848
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
4949
shuffle=False, num_workers=2)

10_FAIR_AI/trt_and_containerization/cifar-script-trt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,20 @@ def do_inference(engine, pics_1, h_input_1, d_input_1, h_output, d_output, strea
9595

9696
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
9797
trt_runtime = trt.Runtime(TRT_LOGGER)
98-
engine = load_engine(trt_runtime, '/home/nravi/ai-science-training-series/cifar_net_10_epochs_250.plan')
98+
engine = load_engine(trt_runtime,'./saved_models/cifar_net_10_epochs_250.plan')
9999
print("Loaded engine")
100100

101101
transform = transforms.Compose(
102102
[transforms.ToTensor(),
103103
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
104104

105105
batch_size = 250
106-
trainset = torchvision.datasets.CIFAR10(root='./', train=True,
106+
data_dir = '../../../ai-science-training-series-old/'
107+
trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True,
107108
download=False, transform=transform)
108109
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
109110
shuffle=True, num_workers=2)
110-
testset = torchvision.datasets.CIFAR10(root='./', train=False,
111+
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False,
111112
download=False, transform=transform)
112113
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
113114
shuffle=False, num_workers=2)
Binary file not shown.

0 commit comments

Comments
 (0)