Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit f3b0f81

Browse files
authored
source: db: Source using the database abstraction
Fixes: #402
1 parent e1cd112 commit f3b0f81

File tree

10 files changed

+211
-8
lines changed

10 files changed

+211
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
- Docstrings and doctestable examples to `record.py`.
1212
- Inputs can be validated using operations
1313
- `validate` parameter in `Input` takes `Operation.instance_name`
14+
- New db source can utilize any database that inherits from `BaseDatabase`
1415
- Logistic Regression with SAG optimizer
1516
- Test tensorflow DNNEstimator documentation examples in CI
1617
- Add python code for tensorflow DNNEstimator

dffml/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ class BaseConfigurable(metaclass=BaseConfigurableMetaClass):
336336
only parameter to the __init__ of a BaseDataFlowFacilitatorObject.
337337
"""
338338

339-
def __init__(self, config: BaseConfig) -> None:
339+
def __init__(self, config: Type[BaseConfig]) -> None:
340340
"""
341341
BaseConfigurable takes only one argument to __init__,
342342
its config, which should inherit from BaseConfig. It shall be an object
@@ -538,7 +538,7 @@ class BaseDataFlowFacilitatorObject(
538538
>>> asyncio.run(main())
539539
"""
540540

541-
def __init__(self, config: BaseConfig) -> None:
541+
def __init__(self, config: Type[BaseConfig]) -> None:
542542
BaseConfigurable.__init__(self, config)
543543
# TODO figure out how to call these in __new__
544544
self.__ensure_property("CONTEXT")

dffml/db/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
Condition = collections.namedtuple(
12-
"Condtion", ["column", "operation", "value"]
12+
"Condition", ["column", "operation", "value"]
1313
)
1414
Conditions = Union[List[List[Condition]], List[List[Tuple[str]]]]
1515

dffml/source/db.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import collections
2+
from typing import Type, AsyncIterator, List
3+
4+
from dffml.base import config, BaseConfig
5+
from dffml.db.base import BaseDatabase, Condition
6+
from dffml.record import Record
7+
from dffml.source.source import BaseSource, BaseSourceContext
8+
from dffml.util.entrypoint import entrypoint
9+
10+
11+
@config
class DbSourceConfig(BaseConfig):
    # Database object; DbSource enters it in __aenter__ and leaves it in
    # __aexit__
    db: BaseDatabase
    # Table passed to every lookup() and insert_or_update() call made by the
    # source context
    table_name: str
    # Columns written by DbSourceContext.update(), in the order they are
    # handed to insert_or_update()
    model_columns: List[str]
16+
17+
18+
class DbSourceContext(BaseSourceContext):
    """
    Reads and writes records as rows of the table named in the parent
    source's config.

    Columns are mapped onto record data by name:

    - ``feature_<name>`` holds the feature ``<name>``
    - ``<target>_value`` and ``<target>_confidence`` hold the prediction made
      for ``<target>``
    - ``key`` holds the key of the record
    """

    _FEATURE_PREFIX = "feature_"
    _VALUE_SUFFIX = "_value"
    _CONFIDENCE_SUFFIX = "_confidence"

    async def update(self, record: Record):
        """
        Insert or update the row for ``record``.

        One value is written for every column listed in the config's
        ``model_columns``, in that order.
        """
        key_value_pairs = collections.OrderedDict(
            (column, self._column_value(record, column))
            for column in self.parent.config.model_columns
        )
        async with self.parent.db() as db_ctx:
            await db_ctx.insert_or_update(
                self.parent.config.table_name, key_value_pairs
            )
            self.logger.debug("update: %s", await self.record(record.key))

    def _column_value(self, record: Record, column: str):
        """
        Return the value to store in ``column`` for ``record``.

        Raises ``KeyError`` if a feature column names a feature the record
        does not have, or if a plain column is not an attribute of the
        record's data.
        """
        if column.startswith(self._FEATURE_PREFIX):
            # Strip only the leading prefix, str.replace() would also remove
            # any later occurrence of "feature_" within the feature name
            return record.data.features[column[len(self._FEATURE_PREFIX) :]]
        if column.endswith(self._VALUE_SUFFIX):
            return self._prediction_field(
                record,
                column[: -len(self._VALUE_SUFFIX)],
                "value",
                "undetermined",
            )
        if column.endswith(self._CONFIDENCE_SUFFIX):
            return self._prediction_field(
                record, column[: -len(self._CONFIDENCE_SUFFIX)], "confidence", 1
            )
        # Any other column (for example "key") is read directly off of the
        # record's data
        return record.data.__dict__[column]

    @staticmethod
    def _prediction_field(record: Record, target: str, field: str, default):
        """
        Return ``field`` of the record's prediction for ``target``, or
        ``default`` when the record holds no prediction for that target.
        """
        prediction = record.data.prediction
        if prediction and target in prediction:
            return prediction[target][field]
        return default

    def _prediction_target(self, column: str):
        """
        Return the target name if ``column`` is a prediction column (ends in
        ``_value`` or ``_confidence``), otherwise ``None``.
        """
        for suffix in (self._VALUE_SUFFIX, self._CONFIDENCE_SUFFIX):
            if column.endswith(suffix):
                return column[: -len(suffix)]
        return None

    async def records(self) -> AsyncIterator[Record]:
        """
        Yield a record for every row of the configured table.
        """
        async with self.parent.db() as db_ctx:
            async for result in db_ctx.lookup(self.parent.config.table_name):
                yield self.convert_to_record(result)

    def convert_to_record(self, result):
        """
        Build a ``Record`` from ``result``, a mapping of column name to value.

        The key of the record is taken from the ``key`` column, the empty
        string is used if there is no such column. If only one of the
        ``<target>_value`` / ``<target>_confidence`` pair is present the
        missing half of the prediction is ``None``.
        """
        features = {}
        prediction = {}
        for column, value in result.items():
            if column.startswith(self._FEATURE_PREFIX):
                features[column[len(self._FEATURE_PREFIX) :]] = value
                continue
            target = self._prediction_target(column)
            # Both columns of a pair map to the same target, only build the
            # prediction the first time one of them is seen
            if target is not None and target not in prediction:
                prediction[target] = {
                    "value": result.get(target + self._VALUE_SUFFIX),
                    "confidence": result.get(
                        target + self._CONFIDENCE_SUFFIX
                    ),
                }
        return Record(
            result.get("key", ""),
            data={"features": features, "prediction": prediction},
        )

    async def record(self, key: str):
        """
        Return the record stored under ``key``.

        If no row matches, a new record with that key and no data is returned.
        """
        record = Record(key)
        async with self.parent.db() as db_ctx:
            rows = db_ctx.lookup(
                self.parent.config.table_name,
                cols=None,  # None turns into *. We want all columns
                conditions=[[Condition("key", "=", key)]],
            )
            try:
                row = await rows.__anext__()
            except StopAsyncIteration:
                # This would happen if there is no matching row, so the async
                # generator reached the end
                return record
            finally:
                # Only the first row is needed. If lookup() returned an async
                # generator, close it now rather than leaving it suspended
                # until it is garbage collected
                aclose = getattr(rows, "aclose", None)
                if aclose is not None:
                    await aclose()

        if row is not None:
            record.merge(self.convert_to_record(row))
        return record
107+
108+
109+
@entrypoint("db")
class DbSource(BaseSource):
    """
    Source which keeps its records in a table of the database given in its
    config.
    """

    CONFIG = DbSourceConfig
    CONTEXT = DbSourceContext

    def __init__(self, cfg: Type[BaseConfig]) -> None:
        super().__init__(cfg)

    async def __aenter__(self) -> "DbSource":
        # Enter the configured database and keep what it returns, contexts of
        # this source access it as self.parent.db
        database = self.config.db
        self.db = await database.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Leave the database entered in __aenter__, forwarding the exception
        # details (if any) unchanged
        await self.db.__aexit__(exc_type, exc_value, traceback)

dffml/util/entrypoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def add_entry_point_label(cls):
8181
def base_entry_point(entrypoint, *args):
8282
"""
8383
Any class which subclasses from Entrypoint needs this decorator applied to
84-
it. The decorator sets the ENTRYPOINT and ENTRY_POINT_NAME proprieties on
84+
it. The decorator sets the ENTRYPOINT and ENTRY_POINT_NAME properties on
8585
the class.
8686
8787
This allows the load() classmethod to be called to load subclasses of the

dffml/util/testing/source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class SourceTest(abc.ABC):
2424
"""
2525

2626
@abc.abstractmethod
27-
async def setUpSource(self, fileobj):
27+
async def setUpSource(self):
2828
pass # pragma: no cover
2929

3030
async def test_update(self):

docs/contributing/codebase.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Codebase Layout And Notes
66
Plugins
77
-------
88

9-
DFFML is plugin based. This means that there the source code for the main
9+
DFFML is plugin based. This means that the source code for the main
1010
package ``dffml``, is separate from the source code for many of the things you
1111
might want to use in conjunction with it. For example, if you wanted to use the
1212
machine learning models based on scikit, you'd install ``dffml-model-scikit``.

scripts/docs/care

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
dffml.operation dffml dffml_feature_git dffml_operations_binsec dffml_feature_auth
22
dffml.model dffml dffml_model_tensorflow dffml_model_tensorflow_hub dffml_model_transformers dffml_model_scratch dffml_model_scikit
33
dffml.source dffml dffml_source_mysql
4-
dffml.service.cli dffml dffml_service_http
4+
dffml.service.cli dffml dffml_service_http

source/mysql/tests/test_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def setUpClass(cls):
4848
ca=cls.ca,
4949
)
5050
# Make it so that when the client tries to connect to mysql.unittest the
51-
# address it get's back is the one for the container
51+
# address it gets back is the one for the container
5252
cls.exit_stack.enter_context(
5353
patch(
5454
"socket.getaddrinfo",

tests/source/test_db.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import os
2+
import sqlite3
3+
import tempfile
4+
from typing import Dict
5+
6+
from dffml.db.sqlite import SqliteDatabaseConfig, SqliteDatabase
7+
from dffml.util.asynctestcase import AsyncTestCase
8+
from dffml.util.testing.source import SourceTest
9+
from dffml.source.db import DbSource, DbSourceConfig
10+
11+
12+
class TestDbSource(AsyncTestCase, SourceTest):
    """
    Runs the SourceTest test cases against a DbSource backed by a SQLite
    database stored in a temporary file.
    """

    db_config: SqliteDatabaseConfig
    source_config: DbSourceConfig
    database_name: str
    cols: Dict[str, str]

    SQL_TEARDOWN = """
    DROP TABLE IF EXISTS `TestTable`;
    """
    SQL_SETUP = """
    CREATE TABLE `TestTable` (
        `key` varchar(100) NOT NULL,
        `feature_PetalLength` float DEFAULT NULL,
        `feature_PetalWidth` float DEFAULT NULL,
        `feature_SepalLength` float DEFAULT NULL,
        `feature_SepalWidth` float DEFAULT NULL,
        `target_name_confidence` float DEFAULT NULL,
        `target_name_value` varchar(100) DEFAULT NULL,
        PRIMARY KEY (`key`)
    );
    """

    @classmethod
    def setUpClass(cls):
        # Name and column definitions of the table used by the tests
        cls.table_name = "TestTable"
        cls.cols = {
            "key": "varchar(100) NOT NULL PRIMARY KEY",
            "feature_PetalLength": "float DEFAULT NULL",
            "feature_PetalWidth": "float DEFAULT NULL",
            "feature_SepalLength": "float DEFAULT NULL",
            "feature_SepalWidth": "float DEFAULT NULL",
            "target_name_confidence": "float DEFAULT NULL",
            "target_name_value": "varchar(100) DEFAULT NULL",
        }

        # Temporary file holding the SQLite database. Only the path is
        # needed, so the descriptor returned by mkstemp is closed right away
        handle, cls.database_name = tempfile.mkstemp(suffix=".db")
        os.close(handle)

        # Config of the database the source will use
        cls.db_config = SqliteDatabaseConfig(cls.database_name)

        # Config of the source under test. The source writes every column of
        # the table, in the order the columns are declared in cls.cols
        cls.source_config = DbSourceConfig(
            db=SqliteDatabase(cls.db_config),
            table_name=cls.table_name,
            model_columns=list(cls.cols),
        )

        # Reset the table using a connection separate from the one used by
        # the source in the tests
        conn = sqlite3.connect(cls.database_name)
        db_cursor = conn.cursor()
        for statement in (cls.SQL_TEARDOWN, cls.SQL_SETUP):
            db_cursor.execute(statement)
        conn.commit()
        db_cursor.close()
        conn.close()

    @classmethod
    def tearDownClass(cls):
        # Delete the temporary database file created in setUpClass
        os.remove(cls.database_name)

    async def setUpSource(self):
        # Source instance the SourceTest test cases are run against
        return DbSource(self.source_config)
77+
78+
79+
# TODO: Potential shortcoming: Is there a way to call this source from the CLI and pass the db object (e.g. SqliteDatabase)?
80+
# dffml list records -sources primary=dbsource -source-db_implementation sqlite -source-table_name testTable -source-db ??? -source-model_columns "key feature_PetalLength feature_PetalWidth feature_SepalLength feature_SepalWidth target_name_confidence target_name_value"

0 commit comments

Comments
 (0)