This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit a08bd6a

model: pytorch: Add PyTorch based pre-trained ConvNet models
Signed-off-by: sakshamarora1 <[email protected]>
1 parent 19624c7 commit a08bd6a

File tree

24 files changed: +1086 −2 lines changed

.github/workflows/testing.yml

Lines changed: 2 additions & 1 deletion

@@ -59,7 +59,7 @@ jobs:
       fail-fast: false
       max-parallel: 40
       matrix:
-        plugin: [., examples/shouldi, model/daal4py, model/tensorflow, model/tensorflow_hub, model/transformers, model/scratch, model/scikit, model/vowpalWabbit, model/autosklearn, model/spacy, operations/binsec, operations/deploy, operations/image, operations/nlp, source/mysql, feature/git, feature/auth, service/http, configloader/yaml, configloader/image]
+        plugin: [., examples/shouldi, model/daal4py, model/tensorflow, model/tensorflow_hub, model/transformers, model/scratch, model/scikit, model/vowpalWabbit, model/autosklearn, model/spacy, model/pytorch, operations/binsec, operations/deploy, operations/image, operations/nlp, source/mysql, feature/git, feature/auth, service/http, configloader/yaml, configloader/image]
         python-version: [3.7, 3.8]

     steps:
@@ -119,6 +119,7 @@ jobs:
           model/scikit=${{ secrets.PYPI_MODEL_SCIKIT }}
           model/vowpalWabbit=${{ secrets.PYPI_MODEL_VOWPALWABBIT }}
           model/autosklearn=${{ secrets.PYPI_MODEL_AUTOSKLEARN }}
+          model/pytorch=${{ secrets.PYPI_MODEL_PYTORCH }}
           source/mysql=${{ secrets.PYPI_SOURCE_MYSQL }}
           feature/git=${{ secrets.PYPI_FEATURE_GIT }}
           feature/auth=${{ secrets.PYPI_FEATURE_AUTH }}
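For reference, the CI matrix above fans out into one job per plugin × Python-version pair. A quick sketch (plain Python, plugin list copied from the post-commit matrix) of how that cross product expands:

```python
from itertools import product

# Plugin axis of the CI matrix after this commit (includes model/pytorch).
plugins = [
    ".", "examples/shouldi", "model/daal4py", "model/tensorflow",
    "model/tensorflow_hub", "model/transformers", "model/scratch",
    "model/scikit", "model/vowpalWabbit", "model/autosklearn", "model/spacy",
    "model/pytorch", "operations/binsec", "operations/deploy",
    "operations/image", "operations/nlp", "source/mysql", "feature/git",
    "feature/auth", "service/http", "configloader/yaml", "configloader/image",
]
python_versions = ["3.7", "3.8"]

# GitHub Actions expands a matrix strategy as the cross product of its axes.
jobs = list(product(plugins, python_versions))
print(len(jobs))  # 22 plugins x 2 Python versions = 44 jobs
```

With `max-parallel: 40`, a few of those 44 jobs queue behind the first wave.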

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]
 ### Added
+- Pre-Trained PyTorch torchvision Models
 - Spacy model for NER
 - Added ability to rename outputs using GetSingle
 - Tutorial for using NLP operations with models

dffml/plugins.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ def inpath(binary):
     ("model", "transformers"),
     ("model", "vowpalWabbit"),
     ("model", "autosklearn"),
+    ("model", "pytorch"),
     ("model", "spacy"),
 ]
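dffml/plugins.py registers each plugin as a `(category, name)` tuple. As an illustration only (the helper functions below are hypothetical, not dffml's actual API), such tuples can be joined into plugin directories and distribution names:

```python
# Hypothetical sketch: turning (category, name) tuples like the ones in
# dffml/plugins.py into plugin directories and package names. The helper
# names here are illustrative, not dffml's real API.
CORE_PLUGINS = [
    ("model", "transformers"),
    ("model", "vowpalWabbit"),
    ("model", "autosklearn"),
    ("model", "pytorch"),
    ("model", "spacy"),
]

def plugin_directory(category: str, name: str) -> str:
    # e.g. ("model", "pytorch") -> "model/pytorch"
    return f"{category}/{name}"

def package_name(category: str, name: str) -> str:
    # e.g. ("model", "pytorch") -> "dffml-model-pytorch"
    return f"dffml-{category}-{name.lower()}"

print(plugin_directory("model", "pytorch"))  # model/pytorch
print(package_name("model", "pytorch"))      # dffml-model-pytorch
```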

model/pytorch/.coveragerc (new file)

Lines changed: 13 additions & 0 deletions

[run]
source =
    dffml_model_pytorch
    tests
branch = True

[report]
exclude_lines =
    no cov
    no qa
    noqa
    pragma: no cover
    if __name__ == .__main__.:
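The `.coveragerc` above is standard coverage.py INI syntax. A small sketch parsing it with Python's `configparser` (the file content is inlined as a string here) shows how the multi-line `source` and `exclude_lines` values are structured:

```python
import configparser

# The .coveragerc above, inlined for illustration.
COVERAGERC = """
[run]
source =
    dffml_model_pytorch
    tests
branch = True

[report]
exclude_lines =
    no cov
    no qa
    noqa
    pragma: no cover
    if __name__ == .__main__.:
"""

config = configparser.ConfigParser()
config.read_string(COVERAGERC)

# Indented continuation lines come back as one newline-separated string.
sources = config["run"]["source"].split()
print(sources)                   # ['dffml_model_pytorch', 'tests']
print(config["run"]["branch"])   # True
```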

model/pytorch/.gitignore (new file)

Lines changed: 20 additions & 0 deletions

*.log
*.pyc
.cache/
.coverage
.idea/
.vscode/
*.egg-info/
build/
dist/
docs/build/
venv/
wheelhouse/
*.egss
.mypy_cache/
*.swp
.venv/
.eggs/
*.modeldir
*.db
htmlcov/
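The patterns above are ordinary globs; note the DFFML-specific entries `*.modeldir` and `*.db` for saved model directories and record databases. A rough sketch using Python's `fnmatch` (which approximates, but does not fully implement, gitignore matching) shows which generated artifacts they catch:

```python
from fnmatch import fnmatch

# A few patterns from the .gitignore above.
patterns = ["*.modeldir", "*.db", "*.log", "htmlcov/"]

def ignored(filename: str) -> bool:
    # Rough approximation only: git's matching has extra rules
    # (directory-only patterns, negation, nested paths) that fnmatch lacks.
    return any(fnmatch(filename, pat.rstrip("/")) for pat in patterns)

print(ignored("resnet18.modeldir"))  # True
print(ignored("records.db"))         # True
print(ignored("train.sh"))           # False
```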

model/pytorch/LICENSE (new file)

Lines changed: 21 additions & 0 deletions

Copyright (c) 2020 Intel, Saksham

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

model/pytorch/MANIFEST.in (new file)

Lines changed: 3 additions & 0 deletions

include README.md
include LICENSE
include setup_common.py

model/pytorch/README.md (new file)

Lines changed: 15 additions & 0 deletions

# DFFML PyTorch Models

## About

dffml_model_pytorch supports pre-trained PyTorch ConvNet models.
[Pre-Trained models](https://pytorch.org/docs/stable/torchvision/models.html)

## Documentation

Documentation is hosted at https://intel.github.io/dffml/plugins/dffml_model.html#dffml-model-pytorch

## License

dffml_model_pytorch Models are distributed under the terms of the
[MIT License](LICENSE).
Lines changed: 163 additions & 0 deletions (new file)

"""
Machine Learning models implemented with `PyTorch <https://pytorch.org/>`_.
Models are saved as `model.pt` under the model directory.

**General Usage:**

Training:

.. code-block:: console

    $ dffml train \\
        -model PYTORCH_MODEL_ENTRYPOINT \\
        -model-features FEATURE_DEFINITION \\
        -model-predict TO_PREDICT \\
        -model-directory MODEL_DIRECTORY \\
        -model-CONFIGS CONFIG_VALUES \\
        -sources f=TRAINING_DATA_SOURCE_TYPE \\
        -source-CONFIGS TRAINING_DATA \\
        -log debug

Testing and Accuracy:

.. code-block:: console

    $ dffml accuracy \\
        -model PYTORCH_MODEL_ENTRYPOINT \\
        -model-features FEATURE_DEFINITION \\
        -model-predict TO_PREDICT \\
        -model-directory MODEL_DIRECTORY \\
        -model-CONFIGS CONFIG_VALUES \\
        -sources f=TESTING_DATA_SOURCE_TYPE \\
        -source-CONFIGS TESTING_DATA \\
        -log debug

Predicting with trained model:

.. code-block:: console

    $ dffml predict all \\
        -model PYTORCH_MODEL_ENTRYPOINT \\
        -model-features FEATURE_DEFINITION \\
        -model-predict TO_PREDICT \\
        -model-directory MODEL_DIRECTORY \\
        -model-CONFIGS CONFIG_VALUES \\
        -sources f=PREDICT_DATA_SOURCE_TYPE \\
        -source-CONFIGS PREDICTION_DATA \\
        -log debug

**Pre-Trained Models Available:**

+----------------+---------------------------------+--------------------+--------------------------------------------------------------------------------+
| Type           | Model                           | Entrypoint         | Architecture                                                                   |
+================+=================================+====================+================================================================================+
| Classification | AlexNet                         | alexnet            | `AlexNet architecture <https://arxiv.org/abs/1404.5997>`_                      |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | DenseNet-121                    | densenet121        | `DenseNet architecture <https://arxiv.org/pdf/1608.06993.pdf>`_                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | DenseNet-161                    | densenet161        |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | DenseNet-169                    | densenet169        |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | DenseNet-201                    | densenet201        |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | MnasNet 0.5                     | mnasnet0_5         | `MnasNet architecture <https://arxiv.org/pdf/1807.11626.pdf>`_                 |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | MnasNet 1.0                     | mnasnet1_0         |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | MobileNet V2                    | mobilenet_v2       | `MobileNet V2 architecture <https://arxiv.org/abs/1801.04381>`_                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-11                          | vgg11              | `VGG-11 architecture Configuration "A" <https://arxiv.org/pdf/1409.1556.pdf>`_ |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-11 with batch normalization | vgg11_bn           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-13                          | vgg13              | `VGG-13 architecture Configuration "B" <https://arxiv.org/pdf/1409.1556.pdf>`_ |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-13 with batch normalization | vgg13_bn           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-16                          | vgg16              | `VGG-16 architecture Configuration "D" <https://arxiv.org/pdf/1409.1556.pdf>`_ |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-16 with batch normalization | vgg16_bn           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-19                          | vgg19              | `VGG-19 architecture Configuration "E" <https://arxiv.org/pdf/1409.1556.pdf>`_ |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | VGG-19 with batch normalization | vgg19_bn           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | GoogleNet                       | googlenet          | `GoogleNet architecture <http://arxiv.org/abs/1409.4842>`_                     |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | Inception V3                    | inception_v3       | `Inception V3 architecture <http://arxiv.org/abs/1512.00567>`_                 |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNet-18                       | resnet18           | `ResNet architecture <https://arxiv.org/pdf/1512.03385.pdf>`_                  |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNet-34                       | resnet34           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNet-50                       | resnet50           |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNet-101                      | resnet101          |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNet-152                      | resnet152          |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | Wide ResNet-101-2               | wide_resnet101_2   | `Wide Resnet architecture <https://arxiv.org/pdf/1605.07146.pdf>`_             |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | Wide ResNet-50-2                | wide_resnet50_2    |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ShuffleNet V2 0.5               | shufflenet_v2_x0_5 | `Shuffle Net V2 architecture <https://arxiv.org/abs/1807.11164>`_              |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ShuffleNet V2 1.0               | shufflenet_v2_x1_0 |                                                                                |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNext-101-32x8D               | resnext101_32x8d   | `ResNext architecture <https://arxiv.org/pdf/1611.05431.pdf>`_                 |
|                +---------------------------------+--------------------+--------------------------------------------------------------------------------+
|                | ResNext-50-32x4D                | resnext50_32x4d    |                                                                                |
+----------------+---------------------------------+--------------------+--------------------------------------------------------------------------------+

**Usage Example:**

The example below uses the ResNet-18 model from the command line.

Let us take a simple example: **Classifying Ants and Bees Images**

First, we download the dataset and verify it with ``sha384sum``:

.. code-block::

    curl -LO https://download.pytorch.org/tutorial/hymenoptera_data.zip
    sha384sum -c - << EOF
    491db45cfcab02d99843fbdcf0574ecf99aa4f056d52c660a39248b5524f9e6e8f896d9faabd27ffcfc2eaca0cec6f39 hymenoptera_data.zip
    EOF
    hymenoptera_data.zip: OK

Unzip the file:

.. code-block::

    unzip hymenoptera_data.zip

Train the model:

.. literalinclude:: /../model/pytorch/examples/resnet18/train.sh

Assess accuracy:

.. literalinclude:: /../model/pytorch/examples/resnet18/accuracy.sh

Output:

.. code-block::

    0.9215686274509803

Create a CSV file listing the images for which we want to predict whether they contain ants or bees.

.. literalinclude:: /../model/pytorch/examples/resnet18/unknown_data.sh

Make the predictions:

.. literalinclude:: /../model/pytorch/examples/resnet18/predict.sh

Output:

.. literalinclude:: /../model/pytorch/examples/resnet18/output.txt

"""
