Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 013877a

Browse files
aghinsapdxjohnny
authored andcommitted
model: tensorflow: Add regression model
* cli: Fixed cli.PredictAll.predict to yield repo * model: tensorflow: Moved predict, _model_dir, train to TensorflowModelContext Fixes: #75 Signed-off-by: John Andersen <[email protected]>
1 parent cfd37f9 commit 013877a

File tree

10 files changed

+749
-202
lines changed

10 files changed

+749
-202
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3737
there are more than 5 issues of high severity and confidence.
3838
- dev service got the ability to run a single operation in a standalone fashion.
3939
- About page to docs.
40+
- Tensorflow DNNEstimator based regression model.
4041
### Changed
4142
- feature/codesec became it's own branch, binsec
4243
- BaseOrchestratorContext `run_operations` strict is default to true. With
@@ -68,6 +69,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6869
prediction as well. Models are not responsible for calling the predicted
6970
method on the repo. This will ease the process of making predict feature
7071
specific.
72+
- Updated Tensorflow model README.md to include usage of regression model
7173
### Fixed
7274
- Docs get version from dffml.version.VERSION.
7375
- FileSource zipfiles are wrapped with TextIOWrapper because CSVSource expects

dffml/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,7 @@ class PredictAll(EvaluateAll, MLCMD):
408408
"""Predicts for all sources"""
409409

410410
async def predict(self, mctx, sctx, repos):
411-
async for repo, value, confidence in mctx.predict(repos):
412-
repo.predicted(value, confidence)
411+
async for repo in mctx.predict(repos):
413412
yield repo
414413
if self.update:
415414
await sctx.update(repo)

docs/plugins/dffml_model.rst

Lines changed: 168 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@ dffml_model_tensorflow
1717
pip install dffml-model-tensorflow
1818
1919
20+
.. note::
21+
22+
It's important to keep the hidden layer config and feature config the same
23+
across invocations of train, predict, and accuracy methods.
24+
25+
Models are saved under the ``directory`` parameter in subdirectories named
26+
after the hash of their feature names and hidden layer config. Which means
27+
if any of those parameters change between invocations, it's being told to
28+
look for a different saved model.
29+
2030
tfdnnc
2131
~~~~~~
2232

2333
*Core*
2434

25-
Implemented using Tensorflow's DNNClassifier. Models are saved under the
26-
``directory`` in subdirectories named after the hash of their feature names.
35+
Implemented using Tensorflow's DNNClassifier.
2736

2837
.. code-block:: console
2938
@@ -33,49 +42,49 @@ Implemented using Tensorflow's DNNClassifier. Models are saved under the
3342
$ sed -i 's/.*setosa,versicolor,virginica/SepalLength,SepalWidth,PetalLength,PetalWidth,classification/g' *.csv
3443
$ head iris_training.csv
3544
$ dffml train \
36-
-model tfdnnc \
37-
-model-epochs 3000 \
38-
-model-steps 20000 \
39-
-model-classification classification \
40-
-model-classifications 0 1 2 \
41-
-model-clstype int \
42-
-sources iris=csv \
43-
-source-filename iris_training.csv \
44-
-features \
45-
def:SepalLength:float:1 \
46-
def:SepalWidth:float:1 \
47-
def:PetalLength:float:1 \
48-
def:PetalWidth:float:1 \
49-
-log debug
45+
-model tfdnnc \
46+
-model-epochs 3000 \
47+
-model-steps 20000 \
48+
-model-classification classification \
49+
-model-classifications 0 1 2 \
50+
-model-clstype int \
51+
-sources iris=csv \
52+
-source-filename iris_training.csv \
53+
-features \
54+
def:SepalLength:float:1 \
55+
def:SepalWidth:float:1 \
56+
def:PetalLength:float:1 \
57+
def:PetalWidth:float:1 \
58+
-log debug
5059
... lots of output ...
5160
$ dffml accuracy \
52-
-model tfdnnc \
53-
-model-classification classification \
54-
-model-classifications 0 1 2 \
55-
-model-clstype int \
56-
-sources iris=csv \
57-
-source-filename iris_test.csv \
58-
-features \
59-
def:SepalLength:float:1 \
60-
def:SepalWidth:float:1 \
61-
def:PetalLength:float:1 \
62-
def:PetalWidth:float:1 \
63-
-log critical
61+
-model tfdnnc \
62+
-model-classification classification \
63+
-model-classifications 0 1 2 \
64+
-model-clstype int \
65+
-sources iris=csv \
66+
-source-filename iris_test.csv \
67+
-features \
68+
def:SepalLength:float:1 \
69+
def:SepalWidth:float:1 \
70+
def:PetalLength:float:1 \
71+
def:PetalWidth:float:1 \
72+
-log critical
6473
0.99996233782
6574
$ dffml predict all \
66-
-model tfdnnc \
67-
-model-classification classification \
68-
-model-classifications 0 1 2 \
69-
-model-clstype int \
70-
-sources iris=csv \
71-
-source-filename iris_test.csv \
72-
-features \
73-
def:SepalLength:float:1 \
74-
def:SepalWidth:float:1 \
75-
def:PetalLength:float:1 \
76-
def:PetalWidth:float:1 \
77-
-caching \
78-
-log critical \
75+
-model tfdnnc \
76+
-model-classification classification \
77+
-model-classifications 0 1 2 \
78+
-model-clstype int \
79+
-sources iris=csv \
80+
-source-filename iris_test.csv \
81+
-features \
82+
def:SepalLength:float:1 \
83+
def:SepalWidth:float:1 \
84+
def:PetalLength:float:1 \
85+
def:PetalWidth:float:1 \
86+
-caching \
87+
-log critical \
7988
> results.json
8089
$ head -n 33 results.json
8190
[
@@ -147,6 +156,124 @@ Implemented using Tensorflow's DNNClassifier. Models are saved under the
147156
- default: <class 'str'>
148157
- Data type of classifications values (default: str)
149158

159+
tfdnnr
160+
~~~~~~
161+
162+
*Core*
163+
164+
Implemented using Tensorflow's DNNEstimator.
165+
166+
Usage:
167+
168+
* predict: Name of the feature we are trying to predict or using for training.
169+
170+
Generating train and test data
171+
172+
* This creates files `train.csv` and `test.csv`,
173+
make sure to take a BACKUP of files with same name in the directory
174+
from where this command is run as it overwrites any existing files.
175+
176+
.. code-block:: console
177+
178+
$ cat > train.csv << EOF
179+
Feature1,Feature2,TARGET
180+
0.93,0.68,3.89
181+
0.24,0.42,1.75
182+
0.36,0.68,2.75
183+
0.53,0.31,2.00
184+
0.29,0.25,1.32
185+
0.29,0.52,2.14
186+
EOF
187+
$ cat > test.csv << EOF
188+
Feature1,Feature2,TARGET
189+
0.57,0.84,3.65
190+
0.95,0.19,2.46
191+
0.23,0.15,0.93
192+
EOF
193+
$ dffml train \
194+
-model tfdnnr \
195+
-model-epochs 300 \
196+
-model-steps 2000 \
197+
-model-predict TARGET \
198+
-model-hidden 8 16 8 \
199+
-sources s=csv \
200+
-source-readonly \
201+
-source-filename train.csv \
202+
-features \
203+
def:Feature1:float:1 \
204+
def:Feature2:float:1 \
205+
-log debug
206+
Enabling debug log shows tensorflow losses...
207+
$ dffml accuracy \
208+
-model tfdnnr \
209+
-model-predict TARGET \
210+
-model-hidden 8 16 8 \
211+
-sources s=csv \
212+
-source-readonly \
213+
-source-filename test.csv \
214+
-features \
215+
def:Feature1:float:1 \
216+
def:Feature2:float:1 \
217+
-log critical
218+
0.9468210011
219+
$ echo -e 'Feature1,Feature2,TARGET\n0.21,0.18,0.84\n' | \
220+
dffml predict all \
221+
-model tfdnnr \
222+
-model-predict TARGET \
223+
-model-hidden 8 16 8 \
224+
-sources s=csv \
225+
-source-readonly \
226+
-source-filename /dev/stdin \
227+
-features \
228+
def:Feature1:float:1 \
229+
def:Feature2:float:1 \
230+
-log critical
231+
[
232+
{
233+
"extra": {},
234+
"features": {
235+
"Feature1": 0.21,
236+
"Feature2": 0.18,
237+
"TARGET": 0.84
238+
},
239+
"last_updated": "2019-10-24T15:26:41Z",
240+
"prediction": {
241+
"confidence": NaN,
242+
"value": 1.1983429193496704
243+
},
244+
"src_url": 0
245+
}
246+
]
247+
248+
The ``NaN`` in ``confidence`` is the expected behaviour. (See TODO in
249+
predict).
250+
251+
**Args**
252+
253+
- directory: String
254+
255+
- default: /home/user/.cache/dffml/tensorflow
256+
- Directory where state should be saved
257+
258+
- steps: Integer
259+
260+
- default: 3000
261+
- Number of steps to train the model
262+
263+
- epochs: Integer
264+
265+
- default: 30
266+
- Number of iterations to pass over all repos in a source
267+
268+
- hidden: List of integers
269+
270+
- default: [12, 40, 15]
271+
- List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer
272+
273+
- predict: String
274+
275+
- Feature name holding truth value
276+
150277
dffml_model_scratch
151278
-------------------
152279

model/tensorflow/README.md

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# DFFML Models for Tensorflow Library
22

3-
## About
4-
5-
DFFML models backed by Tensorflow.
6-
73
## Demo
84

95
![Demo](https://github.com/intel/dffml/raw/master/docs/images/iris_demo.gif)
@@ -12,69 +8,21 @@ DFFML models backed by Tensorflow.
128
> may vary as this video shows accuracy being assessed against the training
139
> data. You should try it for yourself and see!
1410
15-
## Install
16-
17-
```console
18-
virtualenv -p python3.7 .venv
19-
. .venv/bin/activate
20-
python3.7 -m pip install --user -U dffml[tensorflow]
21-
```
11+
## Documentation
2212

23-
## Usage
24-
25-
```console
26-
wget http://download.tensorflow.org/data/iris_training.csv
27-
wget http://download.tensorflow.org/data/iris_test.csv
28-
head iris_training.csv
29-
sed -i 's/.*setosa,versicolor,virginica/SepalLength,SepalWidth,PetalLength,PetalWidth,classification/g' *.csv
30-
head iris_training.csv
31-
dffml train \
32-
-model tfdnnc \
33-
-model-epochs 3000 \
34-
-model-steps 20000 \
35-
-model-classification classification \
36-
-model-classifications 0 1 2 \
37-
-model-clstype int \
38-
-sources iris=csv \
39-
-source-filename iris_training.csv \
40-
-features \
41-
def:SepalLength:float:1 \
42-
def:SepalWidth:float:1 \
43-
def:PetalLength:float:1 \
44-
def:PetalWidth:float:1 \
45-
-log debug
46-
dffml accuracy \
47-
-model tfdnnc \
48-
-model-classification classification \
49-
-model-classifications 0 1 2 \
50-
-model-clstype int \
51-
-sources iris=csv \
52-
-source-filename iris_test.csv \
53-
-features \
54-
def:SepalLength:float:1 \
55-
def:SepalWidth:float:1 \
56-
def:PetalLength:float:1 \
57-
def:PetalWidth:float:1 \
58-
-log critical
59-
dffml predict all \
60-
-model tfdnnc \
61-
-model-classification classification \
62-
-model-classifications 0 1 2 \
63-
-model-clstype int \
64-
-sources iris=csv \
65-
-source-filename iris_test.csv \
66-
-features \
67-
def:SepalLength:float:1 \
68-
def:SepalWidth:float:1 \
69-
def:PetalLength:float:1 \
70-
def:PetalWidth:float:1 \
71-
-caching \
72-
-log critical \
73-
> results.json
74-
head -n 33 results.json
75-
```
13+
Documentation is hosted at https://intel.github.io/dffml/plugins/dffml_model.html#dffml-model-tensorflow
7614

7715
## License
7816

7917
DFFML Tensorflow Models are distributed under the terms of the
8018
[MIT License](LICENSE).
19+
20+
## Legal
21+
22+
> This software is subject to the U.S. Export Administration Regulations and
23+
> other U.S. law, and may not be exported or re-exported to certain countries
24+
> (Cuba, Iran, Crimea Region of Ukraine, North Korea, Sudan, and Syria) or to
25+
> persons or entities prohibited from receiving U.S. exports (including
26+
> Denied Parties, Specially Designated Nationals, and entities on the Bureau
27+
> of Export Administration Entity List or involved with missile technology or
28+
> nuclear, chemical or biological weapons).
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
.. note::
3+
4+
It's important to keep the hidden layer config and feature config the same
5+
across invocations of train, predict, and accuracy methods.
6+
7+
Models are saved under the ``directory`` parameter in subdirectories named
8+
after the hash of their feature names and hidden layer config. Which means
9+
if any of those parameters change between invocations, it's being told to
10+
look for a different saved model.
11+
"""

0 commit comments

Comments
 (0)