Skip to content

Commit a5daf0f

Browse files
author
John Welsh
committed
Merge branch 'SrivastavaKshitij-pr442_rebase_master'
2 parents 817f937 + b55cd01 commit a5daf0f

36 files changed

+1872
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## [Master]
44

5+
- Added Quantization Aware Training (QAT) workflow to contrib
56
- Added converter for ``torch.roll``
67
- Added converter for ``torch.nn.functional.layer_norm``
78
- Added converter for ``torch.nn.functional.gelu``

CONTRIBUTORS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Below is a list of developers who have contributed to torch2trt. This is also used to track contributors
44
who have agreed to torch2trt's Contributor License Agreement.
55

6+
- [John Welsh](https://github.com/jaybdub) (CLA)
67
- John Welsh
78

89
## Becoming a Contributor
@@ -42,6 +43,6 @@ In some instances, you may be requested to sign torch2trt's Contributor License
4243
4. Make a signed commit with the following text
4344

4445
```md
45-
git commit -S -m "I have read and agree to the Contributor License Agreement as written in the file CLA.pdf of this project. Signed, <Full Name>"
46+
git commit -S -m "I have read and agree to the Contributor License Agreement as written in the file CLA.md of this project. Signed, <Full Name>"
4647
```
4748

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ cd torch2trt
115115
python setup.py install
116116
```
117117

118-
### Option 2 - With plugins (experimental)
118+
### Option 2 - With plugins
119119

120120
To install with plugins to support some operations in PyTorch that are not natviely supported with TensorRT, call the following
121121

@@ -127,6 +127,19 @@ cd torch2trt
127127
sudo python setup.py install --plugins
128128
```
129129

130+
### Option 3 - With support for experimental community contributed features
131+
132+
To install torch2trt with experimental community contributed features under ``torch2trt.contrib``, like Quantization Aware Training (QAT)(`requires TensorRT>=7.0`), call the following,
133+
134+
```bash
135+
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
136+
cd torch2trt/scripts
137+
bash build_contrib.sh
138+
```
139+
140+
This enables you to run the QAT example located [here](examples/contrib/quantization_aware_training).
141+
142+
130143
## How does it work?
131144

132145
This converter works by attaching conversion functions (like ``convert_ReLU``) to the original
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
## QAT working example
2+
3+
This example is using QAT library open sourced by nvidia. [Github link](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization)
4+
5+
## Directory overview
6+
7+
1. This directory contains
8+
1. `dataset` : contains code for cifar-10 dataset
9+
2. `layers` : contains implementation for inference. More details under `layers/README.md`
10+
3. `models`: contains two models. `resnet18` and `vanilla_cnn`
11+
4. `utils` : contains various utility functions for loading state dict, custom wrapper for training and inference & calculating accuracy during training
12+
5. `train.py` and `infer.py` : contains code for training and inference (including trt conversion)
13+
14+
2. Usually, nvidia quantization library doesn't provide control per layer for quantization. Custom wrapper under `utils/utilities.py` helps us in quantization selective layers in our model.
15+
16+
## Environment
17+
18+
**Filename** : pytorch_ngc_container_20.09
19+
20+
```
21+
FROM nvcr.io/nvidia/pytorch:20.09-py3
22+
RUN apt-get update && apt-get install -y software-properties-common && apt-get update
23+
RUN add-apt-repository ppa:git-core/ppa && \
24+
apt install -y git
25+
26+
RUN pip install termcolor graphviz
27+
28+
## If you have followed instructions on main README.md file to install torch2trt using scripts/build_contrib.sh
29+
## You dont require rest of the steps
30+
31+
RUN git clone https://github.com/NVIDIA/TensorRT.git /sw/TensorRT/
32+
33+
##Make sure that patch file is under the same folder where dockerfile is being called
34+
35+
ADD pytorch_nvidia_quantization.patch /sw/TensorRT
36+
37+
RUN cd /sw/TensorRT/ && \
38+
git sparse-checkout init --cone && \
39+
git sparse-checkout set /tools/pytorch-quantization/ && \
40+
git apply --reject --whitespace=fix pytorch_nvidia_quantization.patch && \
41+
cd tools/pytorch-quantization/ && \
42+
python setup.py install
43+
44+
RUN git clone https://github.com/NVIDIA-AI-IOT/torch2trt.git /sw/TensorRT/ && \
45+
cd /sw/TensorRT/ && \
46+
git fetch origin pull/514/head:PR514 && \
47+
git checkout PR514 && \
48+
python setup.py install --plugins
49+
50+
```
51+
52+
Docker build: `docker build -f pytorch_ngc_container_20.09 -t pytorch_ngc_container_20.09 .`
53+
54+
`docker_image=pytorch_ngc_container_20.09`
55+
56+
Docker run : `docker run -e NVIDIA_VISIBLE_DEVICES=0 --gpus 0 -it --shm-size=1g --ulimit memlock=-1 --rm -v $PWD:/workspace/work $docker_image`
57+
58+
**Important Notes** :
59+
60+
- Sparse checkout helps us in checking out a part of the github repo.
61+
- Patch file can be found under `examples/quantization_aware_training/utils`
62+
63+
## Workflow
64+
65+
Workflow consists of three parts.
66+
1. Train without quantization:
67+
68+
Here pretrained weights from imagenet are used.
69+
70+
`python train.py --m resnet34-tl / resnet18-tl --num_epochs 45 --test_trt --FP16 --INT8PTC`
71+
72+
2. Train with quantization (weights are mapped using a custom function to make sure that each weight is loaded correctly)
73+
74+
`python train.py --m resnet34/ resnet18 --netqat --partial_ckpt --tl --load_ckpt /tmp/pytorch_exp/{} --num_epochs 25 --lr 1e-4 --lrdt 10`
75+
76+
3. Infer with and without TRT
77+
78+
`python infer.py --m resnet34/resnet18 --load_ckpt /tmp/pytorch_exp_1/ckpt_{} --netqat --INT8QAT`
79+
80+
81+
## Accuracy Results
82+
83+
| Model | FP32 | FP16 | INT8 (QAT) | INT(PTC) |
84+
|-------|------|------|------------|----------|
85+
| Resnet18 | 83.08 | 83.12 | 83.12 | 83.06 |
86+
| Resnet34 | 84.65 | 84.65 | 83.26 | 84.5 |
87+
88+
89+
**Please note that the idea behind these experiments is to see if TRT conversion is working properly rather than achieving industry standard accuracy results**
90+
91+
## Future Work
92+
93+
- Add results for Resnet50, EfficientNet and Mobilenet
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .layers import *

examples/contrib/quantization_aware_training/datasets/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
import torchvision
3+
import torchvision.transforms as transforms
4+
5+
class Cifar10Loaders:
6+
"""
7+
Data loaders for cifar 10 dataset
8+
"""
9+
def __init__(self, data_dir='/tmp/cifar10', download=True, batch_size=128, pin_memory=True, num_workers=4):
10+
self.data_dir = data_dir
11+
self.download = download
12+
self.batch_size= batch_size
13+
self.pin_memory = pin_memory
14+
self.num_workers = num_workers
15+
self.train_transform = transforms.Compose([
16+
transforms.RandomCrop(32, padding=4),
17+
transforms.RandomHorizontalFlip(),
18+
transforms.ToTensor(),
19+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
20+
])
21+
self.test_transform = transforms.Compose([
22+
transforms.ToTensor(),
23+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
24+
])
25+
26+
def train_loader(self,shuffle=True):
27+
trainset = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=self.train_transform)
28+
trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
29+
return trainloader
30+
31+
def test_loader(self,shuffle=False):
32+
testset = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=self.test_transform)
33+
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
34+
return testloader
35+
36+
37+
38+
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import timeit
2+
import torch
3+
import torch.nn as nn
4+
import numpy as np
5+
import torchvision
6+
import argparse
7+
import os,sys
8+
from datasets.cifar10 import Cifar10Loaders
9+
from utils.utilities import calculate_accuracy, timeGraph,printStats
10+
from models.resnet import resnet18,resnet34
11+
from parser import parse_args
12+
from torch2trt import torch2trt
13+
import tensorrt as trt
14+
torch.set_printoptions(precision=5)
15+
16+
def main():
17+
args = parse_args()
18+
19+
args.cuda = not args.no_cuda and torch.cuda.is_available()
20+
torch.manual_seed(78543)
21+
22+
if args.cuda:
23+
torch.backends.cudnn.benchmark = True
24+
torch.cuda.manual_seed(args.seed)
25+
26+
loaders = Cifar10Loaders()
27+
train_loader = loaders.train_loader()
28+
test_loader = loaders.test_loader()
29+
30+
if args.m == "resnet18":
31+
if args.netqat:
32+
model=resnet18(qat_mode=True,infer=True)
33+
else:
34+
model=resnet18()
35+
elif args.m == "resnet34":
36+
if args.netqat:
37+
model=resnet34(qat_mode=True,infer=True)
38+
else:
39+
model=resnet34()
40+
else:
41+
raise NotImplementedError("{} model not found".format(args.m))
42+
43+
44+
model = model.cuda().eval()
45+
46+
if args.load_ckpt:
47+
checkpoint = torch.load(args.load_ckpt)
48+
if not args.netqat:
49+
checkpoint = mapping_names_resnets(checkpoint)
50+
model.load_state_dict(checkpoint['model_state_dict'],strict=True)
51+
print("===>>> Checkpoint loaded successfully from {} ".format(args.load_ckpt))
52+
53+
test_accuracy = calculate_accuracy(model,test_loader)
54+
print(" Test accuracy for Pytorch model: {0} ".format(test_accuracy))
55+
rand_in = torch.randn([128,3,32,32],dtype=torch.float32).cuda()
56+
57+
#Converting the model to TRT
58+
if args.FP16:
59+
trt_model_fp16 = torch2trt(model,[rand_in],log_level=trt.Logger.INFO,fp16_mode=True,max_batch_size=128)
60+
test_accuracy = calculate_accuracy(trt_model_fp16,test_loader)
61+
print(" TRT test accuracy at FP16: {0}".format(test_accuracy))
62+
63+
if args.INT8QAT:
64+
trt_model_int8 = torch2trt(model,[rand_in],log_level=trt.Logger.INFO,fp16_mode=True,int8_mode=True,max_batch_size=128,qat_mode=True)
65+
test_accuracy = calculate_accuracy(trt_model_int8,test_loader)
66+
print(" TRT test accuracy at INT8 QAT: {0}".format(test_accuracy))
67+
68+
if args.INT8PTC:
69+
##preparing calib dataset
70+
calib_dataset = list()
71+
for i, sam in enumerate(test_loader):
72+
calib_dataset.extend(sam[0])
73+
if i ==5:
74+
break
75+
76+
trt_model_calib_int8 = torch2trt(model,[rand_in],log_level=trt.Logger.INFO,fp16_mode=True,int8_calib_dataset=calib_dataset,int8_mode=True,max_batch_size=128)
77+
test_accuracy = calculate_accuracy(trt_model_calib_int8,test_loader)
78+
print(" TRT test accuracy at INT8 PTC: {0}".format(test_accuracy))
79+
80+
if __name__ == "__main__":
81+
main()

examples/contrib/quantization_aware_training/models/__init__.py

Whitespace-only changes.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
'''
2+
Contains basic model definitions
3+
'''
4+
5+
import torch
6+
import torch.nn as nn
7+
from utils.utilities import qrelu,qconv2d
8+
9+
class vanilla_cnn(nn.Module):
10+
def __init__(self,qat_mode=False,infer=False):
11+
super().__init__()
12+
self.qat = qat_mode
13+
self.layer1=qconv2d(3,32,padding=1,qat=qat_mode,infer=infer)
14+
self.layer2=qconv2d(32,64,padding=1,qat=qat_mode,infer=infer)
15+
self.layer3=qconv2d(64,128,padding=1,qat=qat_mode,infer=infer)
16+
self.layer4=qconv2d(128,256,padding=1,qat=qat_mode,infer=infer)
17+
self.layer5 = nn.MaxPool2d(kernel_size=2,stride=8)
18+
self.fcs = nn.Sequential(
19+
nn.Linear(4096,1024),
20+
nn.ReLU(),
21+
nn.Linear(1024,512),
22+
nn.ReLU(),
23+
nn.Linear(512,10))
24+
25+
def forward(self,x):
26+
x = self.layer1(x)
27+
x = self.layer2(x)
28+
x = self.layer3(x)
29+
x = self.layer4(x)
30+
x = self.layer5(x)
31+
x = x.view(x.size(0),-1)
32+
x = self.fcs(x)
33+
return x
34+
35+
36+

0 commit comments

Comments
 (0)