Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit ca96c32

Browse files
committed
tests: integration: CSV source string src_urls
Signed-off-by: John Andersen <[email protected]>
1 parent fe54147 commit ca96c32

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

tests/integration/common.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import asyncio
1919
import contextlib
2020
import unittest.mock
21-
from typing import Dict, Any
21+
from typing import Dict, Any, Optional
2222

2323
from dffml.repo import Repo
2424
from dffml.base import config
@@ -78,5 +78,12 @@ def required_plugins(self, *args):
7878
f"Required plugins: {', '.join(args)} must be installed in development mode"
7979
)
8080

81-
def mktempfile(self):
82-
return self._stack.enter_context(non_existant_tempfile())
81+
def mktempfile(
82+
self, suffix: Optional[str] = None, text: Optional[str] = None
83+
):
84+
filename = self._stack.enter_context(non_existant_tempfile())
85+
if suffix:
86+
filename = filename + suffix
87+
if text:
88+
pathlib.Path(filename).write_text(inspect.cleandoc(text) + "\n")
89+
return filename

tests/integration/test_sources.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
This file contains integration tests. We use the CLI to exercise functionality of
3+
various DFFML classes and constructs.
4+
"""
5+
import re
6+
import os
7+
import io
8+
import json
9+
import inspect
10+
import pathlib
11+
import asyncio
12+
import contextlib
13+
import unittest.mock
14+
from typing import Dict, Any
15+
16+
from dffml.repo import Repo
17+
from dffml.base import config
18+
from dffml.df.types import Definition, Operation, DataFlow, Input
19+
from dffml.df.base import op
20+
from dffml.cli.cli import CLI
21+
from dffml.model.model import Model
22+
from dffml.service.dev import Develop
23+
from dffml.util.packaging import is_develop
24+
from dffml.util.entrypoint import load
25+
from dffml.config.config import BaseConfigLoader
26+
from dffml.util.asynctestcase import AsyncTestCase
27+
28+
from .common import IntegrationCLITestCase
29+
30+
31+
class TestCSV(IntegrationCLITestCase):
32+
async def test_string_src_urls(self):
33+
# Test for issue #207
34+
self.required_plugins("dffml-model-scikit")
35+
# Create the training data
36+
train_filename = self.mktempfile(
37+
suffix=".csv",
38+
text="""
39+
Years,Expertise,Trust,Salary
40+
0,1,0.2,10
41+
1,3,0.4,20
42+
2,5,0.6,30
43+
3,7,0.8,40
44+
""",
45+
)
46+
# Create the test data
47+
test_filename = self.mktempfile(
48+
suffix=".csv",
49+
text="""
50+
Years,Expertise,Trust,Salary
51+
4,9,1.0,50
52+
5,11,1.2,60
53+
""",
54+
)
55+
# Create the prediction data
56+
predict_filename = self.mktempfile(
57+
suffix=".csv",
58+
text="""
59+
Years,Expertise,Trust
60+
6,13,1.4
61+
""",
62+
)
63+
# Features
64+
features = "-model-features def:Years:int:1 def:Expertise:int:1 def:Trust:float:1".split()
65+
# Train the model
66+
await CLI.cli(
67+
"train",
68+
"-model",
69+
"scikitlr",
70+
*features,
71+
"-model-predict",
72+
"Salary",
73+
"-sources",
74+
"training_data=csv",
75+
"-source-filename",
76+
train_filename,
77+
)
78+
# Assess accuracy
79+
await CLI.cli(
80+
"accuracy",
81+
"-model",
82+
"scikitlr",
83+
*features,
84+
"-model-predict",
85+
"Salary",
86+
"-sources",
87+
"test_data=csv",
88+
"-source-filename",
89+
test_filename,
90+
)
91+
# Ensure JSON output works as expected (#261)
92+
with contextlib.redirect_stdout(self.stdout):
93+
# Make prediction
94+
await CLI._main(
95+
"predict",
96+
"all",
97+
"-model",
98+
"scikitlr",
99+
*features,
100+
"-model-predict",
101+
"Salary",
102+
"-sources",
103+
"predict_data=csv",
104+
"-source-filename",
105+
predict_filename,
106+
)
107+
results = json.loads(self.stdout.getvalue())
108+
self.assertTrue(isinstance(results, list))
109+
self.assertTrue(results)
110+
results = results[0]
111+
self.assertIn("src_url", results)
112+
self.assertEqual("0", results["src_url"])
113+
self.assertIn("prediction", results)
114+
self.assertIn("value", results["prediction"])
115+
self.assertEqual(70.0, results["prediction"]["value"])

0 commit comments

Comments
 (0)