This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit a9da2d5

model: tensorflow hub: Add NLP model
* bert and nnlm embedding
* model: tensorflow: util: Make tensorflow config

Signed-off-by: John Andersen <[email protected]>
1 parent 1c9a228 commit a9da2d5

31 files changed: +1964 -10 lines changed

.ci/run.sh

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ function run_plugin() {
 
   "${PYTHON}" -m pip install -U "${SRC_ROOT}"
 
+  if [ "x${PLUGIN}" = "xmodel/tensorflow_hub" ]; then
+    "${PYTHON}" -m pip install -U "${SRC_ROOT}/model/tensorflow"
+  fi
+
   cd "${PLUGIN}"
   PACKAGE_NAME=$(dffml service dev setuppy kwarg name setup.py)
   "${PYTHON}" -m pip install -e .

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ jobs:
       fail-fast: false
       max-parallel: 40
       matrix:
-        plugin: [., examples/shouldi, model/tensorflow, model/scratch, model/scikit, source/mysql, feature/git, feature/auth, service/http, config/yaml]
+        plugin: [., examples/shouldi, model/tensorflow, model/tensorflow_hub, model/scratch, model/scikit, source/mysql, feature/git, feature/auth, service/http, config/yaml]
         python-version: [3.7]
 
     steps:

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 ### Added
+- Tensorflow hub NLP models.
 - Notes on development dependencies in `setup.py` files to codebase notes.
 
 ## [0.3.3] - 2020-02-10

dffml/service/dev.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     ("model", "tensorflow"),
     ("model", "scratch"),
     ("model", "scikit"),
+    ("model", "tensorflow_hub"),
     ("examples", "shouldi"),
     ("feature", "git"),
     ("feature", "auth"),

docs/plugins/dffml_model.rst

Lines changed: 173 additions & 0 deletions
@@ -297,6 +297,179 @@ predict).
   - default: [12, 40, 15]
   - List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer
 
+dffml_model_tensorflow_hub
+--------------------------
+
+.. code-block:: console
+
+    pip install dffml-model-tensorflow-hub
+
+
+text_classifier
+~~~~~~~~~~~~~~~
+
+*Core*
+
+Implemented using Tensorflow hub pretrained models.
+
+.. code-block:: console
+
+    $ cat > train.csv << EOF
+    sentence,sentiment
+    Life is good,1
+    This book is amazing,1
+    It's a terrible movie,0
+    Global warming is bad,0
+    EOF
+    $ cat > test.csv << EOF
+    sentence,sentiment
+    I am not feeling good,0
+    Our trip was full of adventures,1
+    EOF
+    $ dffml train \
+        -model text_classifier \
+        -model-epochs 30 \
+        -model-predict sentiment:int:1 \
+        -model-classifications 0 1 \
+        -model-clstype int \
+        -sources f=csv \
+        -source-filename train.csv \
+        -model-features \
+          sentence:str:1 \
+        -model-model_path "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1" \
+        -model-add_layers \
+        -model-layers "Dense(units=512, activation='relu')" "Dense(units=2, activation='softmax')" \
+        -log debug
+    $ dffml accuracy \
+        -model text_classifier \
+        -model-predict sentiment:int:1 \
+        -model-classifications 0 1 \
+        -model-model_path "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1" \
+        -model-clstype int \
+        -sources f=csv \
+        -source-filename test.csv \
+        -model-features \
+          sentence:str:1 \
+        -log critical
+    1.0
+    $ dffml predict all \
+        -model text_classifier \
+        -model-predict sentiment:int:1 \
+        -model-classifications 0 1 \
+        -model-model_path "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1" \
+        -model-clstype int \
+        -sources f=csv \
+        -source-filename test.csv \
+        -model-features \
+          sentence:str:1 \
+        -log debug
+    [
+        {
+            "extra": {},
+            "features": {
+                "sentence": "I am not feeling good",
+                "sentiment": 0
+            },
+            "key": "0",
+            "last_updated": "2020-02-15T02:54:02Z",
+            "prediction": {
+                "sentiment": {
+                    "confidence": 0.7630850076675415,
+                    "value": 0
+                }
+            }
+        },
+        {
+            "extra": {},
+            "features": {
+                "sentence": "Our trip was full of adventures",
+                "sentiment": 1
+            },
+            "key": "1",
+            "last_updated": "2020-02-15T02:54:02Z",
+            "prediction": {
+                "sentiment": {
+                    "confidence": 0.6673157811164856,
+                    "value": 1
+                }
+            }
+        }
+    ]
+
+**Args**
+
+- predict: Feature
+
+  - Feature name holding classification value
+
+- classifications: List of strings
+
+  - Options for value of classification
+
+- features: List of features
+
+  - Features to train on
+
+- trainable: String
+
+  - default: True
+  - Tweak pretrained model by training again
+
+- batch_size: Integer
+
+  - default: 120
+  - Batch size
+
+- max_seq_length: Integer
+
+  - default: 256
+  - Length of sentence, used in preprocessing of input for bert embedding
+
+- add_layers: Boolean
+
+  - default: False
+  - Add layers on the top of pretrained model/layer
+
+- embedType: String
+
+  - default: None
+  - Type of pretrained embedding model, required to be set to `bert` to use bert pretrained embedding
+
+- layers: List of strings
+
+  - default: None
+  - Extra layers to be added on top of pretrained model
+
+- model_path: String
+
+  - default: https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1
+  - Pretrained model path/url
+
+- optimizer: String
+
+  - default: adam
+  - Optimizer used by model
+
+- metrics: String
+
+  - default: accuracy
+  - Metric used to evaluate model
+
+- clstype: Type
+
+  - default: <class 'str'>
+  - Data type of classifications values
+
+- epochs: Integer
+
+  - default: 10
+  - Number of iterations to pass over all repos in a source
+
+- directory: String
+
+  - default: /home/user/.cache/dffml/tensorflow_hub
+  - Directory where state should be saved
+
 dffml_model_scratch
 -------------------
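
Note on the `layers` argument documented in this hunk: the values passed to `-model-layers` are strings that look like Keras layer constructor calls, for example `"Dense(units=512, activation='relu')"`. The sketch below is an illustration only and is not taken from this commit. It shows one way such a string could be split into a layer name and keyword arguments without calling `eval`. The helper name `parse_layer_spec` is hypothetical and is not part of DFFML or TensorFlow, and the restriction to literal keyword arguments is an assumption of this sketch.

```python
import ast


def parse_layer_spec(spec):
    """Hypothetical helper: split "Dense(units=512, activation='relu')"
    into ("Dense", {"units": 512, "activation": "relu"}).

    Only a bare name called with literal keyword arguments is accepted,
    so nothing in the string is ever executed.
    """
    node = ast.parse(spec.strip(), mode="eval").body
    if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name):
        raise ValueError(f"not a simple layer call: {spec!r}")
    if node.args:
        raise ValueError("positional arguments are not supported in this sketch")
    kwargs = {}
    for keyword in node.keywords:
        if keyword.arg is None:  # rejects **kwargs expansion
            raise ValueError("'**' expansion is not supported in this sketch")
        # literal_eval accepts an AST node and only allows Python literals
        kwargs[keyword.arg] = ast.literal_eval(keyword.value)
    return node.func.id, kwargs


if __name__ == "__main__":
    specs = [
        "Dense(units=512, activation='relu')",
        "Dense(units=2, activation='softmax')",
    ]
    for name, kwargs in map(parse_layer_spec, specs):
        print(name, kwargs)
```

A caller could then look the returned name up in a table of allowed layer classes and construct the layer with the parsed keyword arguments. How the plugin itself turns these strings into layers is not shown in this part of the diff.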

model/scikit/tests/test_scikit_integration.py

Lines changed: 1 addition & 6 deletions
@@ -9,14 +9,12 @@
 import contextlib
 
 import numpy as np
+from sklearn.datasets import make_blobs
 
 from dffml.cli.cli import CLI
 from dffml.util.asynctestcase import IntegrationCLITestCase
 
 
-from sklearn.datasets import make_blobs
-
-
 class TestScikitClassification(IntegrationCLITestCase):
     async def test_run(self):
         self.required_plugins("dffml-model-scikit")
@@ -306,7 +304,6 @@ async def test_run(self):
             "training_data=csv",
             "-source-filename",
             train_file,
-            "-source-readonly",
         )
         # Assess accuracy
         await CLI.cli(
@@ -318,7 +315,6 @@ async def test_run(self):
                 *features,
                 "-sources",
                 "test_data=csv",
-                "-source-readonly",
                 "-source-filename",
                 test_file,
             ]
@@ -339,7 +335,6 @@ async def test_run(self):
             *features,
             "-sources",
             "predict_data=csv",
-            "-source-readonly",
             "-source-filename",
             predict_file,
         )

model/tensorflow/dffml_model_tensorflow/util/__init__.py

Whitespace-only changes.

model/tensorflow/dffml_model_tensorflow/util/config/__init__.py

Whitespace-only changes.
