Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 00b693a

Browse files
John Andersensakshamarora1
authored andcommitted
cli: Fix numpy int* and float* json output
Fixes: #261 Co-authored-by: Saksham Arora <[email protected]> Signed-off-by: John Andersen <[email protected]>
1 parent 7a64def commit 00b693a

File tree

3 files changed

+147
-11
lines changed

3 files changed

+147
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
- MySQL packaging issue.
2929
- Develop service running one off operations correctly json-loads dict types.
3030
- Operations with configs can be run via the development service
31+
- JSON dumping numpy int\* and float\* caused crash on dump.
3132
### Removed
3233
- CLI command `operations` removed in favor of `dataflow run`
3334
- Duplicate dataflow diagram code from development service

dffml/util/cli/cmd.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class JSONEncoder(json.JSONEncoder):
3232
"""
3333

3434
def default(self, obj):
35+
typename_lower = str(type(obj)).lower()
3536
if isinstance(obj, Repo):
3637
return obj.dict()
3738
elif isinstance(obj, Feature):
@@ -40,6 +41,11 @@ def default(self, obj):
4041
return str(obj.value)
4142
elif isinstance(obj, type):
4243
return str(obj.__qualname__)
44+
elif "numpy." in typename_lower:
45+
if ".int" in typename_lower or ".uint" in typename_lower:
46+
return int(obj)
47+
elif typename_lower.startswith("float"):
48+
return float(obj)
4349
elif str(obj).startswith("typing."):
4450
return str(obj).split(".")[-1]
4551
return json.JSONEncoder.default(self, obj)
@@ -169,17 +175,8 @@ def sanitize_args(cls, args):
169175
return args
170176

171177
@classmethod
172-
def main(cls, loop=asyncio.get_event_loop(), argv=sys.argv):
173-
"""
174-
Runs cli commands in asyncio loop and outputs in appropriate format
175-
"""
176-
result = None
177-
try:
178-
result = loop.run_until_complete(cls.cli(*argv[1:]))
179-
except KeyboardInterrupt: # pragma: no cover
180-
pass # pragma: no cover
181-
loop.run_until_complete(loop.shutdown_asyncgens())
182-
loop.close()
178+
async def _main(cls, *args):
179+
result = await cls.cli(*args)
183180
if not result is None and result is not DisplayHelp:
184181
json.dump(
185182
result,
@@ -191,6 +188,21 @@ def main(cls, loop=asyncio.get_event_loop(), argv=sys.argv):
191188
)
192189
print()
193190

191+
@classmethod
192+
def main(cls, loop=None, argv=sys.argv):
193+
"""
194+
Runs cli commands in asyncio loop and outputs in appropriate format
195+
"""
196+
if loop is None:
197+
loop = asyncio.get_event_loop()
198+
result = None
199+
try:
200+
result = loop.run_until_complete(cls._main(*argv[1:]))
201+
except KeyboardInterrupt: # pragma: no cover
202+
pass # pragma: no cover
203+
loop.run_until_complete(loop.shutdown_asyncgens())
204+
loop.close()
205+
194206
@classmethod
195207
def args(cls, args, *above) -> Dict[str, Any]:
196208
"""

tests/integration/test_models.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
This file contains integration tests. We use the CLI to exercise functionality of
3+
various DFFML classes and constructs.
4+
"""
5+
import re
6+
import os
7+
import io
8+
import json
9+
import inspect
10+
import pathlib
11+
import asyncio
12+
import contextlib
13+
import unittest.mock
14+
from typing import Dict, Any
15+
16+
from dffml.repo import Repo
17+
from dffml.base import config
18+
from dffml.df.types import Definition, Operation, DataFlow, Input
19+
from dffml.df.base import op
20+
from dffml.cli.cli import CLI
21+
from dffml.model.model import Model
22+
from dffml.service.dev import Develop
23+
from dffml.util.packaging import is_develop
24+
from dffml.util.entrypoint import load
25+
from dffml.config.config import BaseConfigLoader
26+
from dffml.util.asynctestcase import AsyncTestCase
27+
28+
from .common import IntegrationCLITestCase
29+
30+
31+
class TestScikitClassification(IntegrationCLITestCase):
32+
async def test_run(self):
33+
self.required_plugins("dffml-model-scikit")
34+
# Create the training data
35+
train_filename = self.mktempfile() + ".csv"
36+
pathlib.Path(train_filename).write_text(
37+
inspect.cleandoc(
38+
"""
39+
Clump_Thickness,Uniformity_of_Cell_Size,Uniformity_of_Cell_Shape,Marginal_Adhesion,Single_Epithelial_Cell_Size,Bare_Nuclei,Bland_Chromatin,Normal_Nucleoli,Mitoses,Class
40+
3,4,5,2,6,8,4,1,1,4
41+
1,1,1,1,3,2,2,1,1,2
42+
3,1,1,3,8,1,5,8,1,2
43+
8,8,7,4,10,10,7,8,7,4
44+
"""
45+
)
46+
+ "\n"
47+
)
48+
# Create the test data
49+
test_filename = self.mktempfile() + ".csv"
50+
pathlib.Path(test_filename).write_text(
51+
inspect.cleandoc(
52+
"""
53+
Clump_Thickness,Uniformity_of_Cell_Size,Uniformity_of_Cell_Shape,Marginal_Adhesion,Single_Epithelial_Cell_Size,Bare_Nuclei,Bland_Chromatin,Normal_Nucleoli,Mitoses,Class
54+
1,1,1,1,1,1,3,1,1,2
55+
7,2,4,1,6,10,5,4,3,4
56+
"""
57+
)
58+
+ "\n"
59+
)
60+
# Create the prediction data
61+
predict_filename = self.mktempfile() + ".csv"
62+
pathlib.Path(predict_filename).write_text(
63+
inspect.cleandoc(
64+
"""
65+
Clump_Thickness,Uniformity_of_Cell_Size,Uniformity_of_Cell_Shape,Marginal_Adhesion,Single_Epithelial_Cell_Size,Bare_Nuclei,Bland_Chromatin,Normal_Nucleoli,Mitoses,Class
66+
5,3,3,3,6,10,3,1,1
67+
"""
68+
)
69+
+ "\n"
70+
)
71+
# Features
72+
features = "-model-features def:Clump_Thickness:int:1 def:Uniformity_of_Cell_Size:int:1 def:Uniformity_of_Cell_Shape:int:1 def:Marginal_Adhesion:int:1 def:Single_Epithelial_Cell_Size:int:1 def:Bare_Nuclei:int:1 def:Bland_Chromatin:int:1 def:Normal_Nucleoli:int:1 def:Mitoses:int:1".split()
73+
# Train the model
74+
await CLI.cli(
75+
"train",
76+
"-model",
77+
"scikitsvc",
78+
*features,
79+
"-model-predict",
80+
"Class",
81+
"-sources",
82+
"training_data=csv",
83+
"-source-filename",
84+
train_filename,
85+
)
86+
# Assess accuracy
87+
await CLI.cli(
88+
"accuracy",
89+
"-model",
90+
"scikitsvc",
91+
*features,
92+
"-model-predict",
93+
"Class",
94+
"-sources",
95+
"test_data=csv",
96+
"-source-filename",
97+
test_filename,
98+
)
99+
# Ensure JSON output works as expected (#261)
100+
with contextlib.redirect_stdout(self.stdout):
101+
# Make prediction
102+
await CLI._main(
103+
"predict",
104+
"all",
105+
"-model",
106+
"scikitsvc",
107+
*features,
108+
"-model-predict",
109+
"Class",
110+
"-sources",
111+
"predict_data=csv",
112+
"-source-filename",
113+
predict_filename,
114+
)
115+
results = json.loads(self.stdout.getvalue())
116+
self.assertTrue(isinstance(results, list))
117+
self.assertTrue(results)
118+
results = results[0]
119+
self.assertIn("prediction", results)
120+
results = results["prediction"]
121+
self.assertIn("value", results)
122+
results = results["value"]
123+
self.assertEqual(4, results)

0 commit comments

Comments
 (0)