Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 665b81d

Browse files
John Andersen (pdxjohnny)
authored and committed
run model predict
Signed-off-by: John Andersen <[email protected]>
1 parent a3f2027 commit 665b81d

File tree

13 files changed

+595
-339
lines changed

13 files changed

+595
-339
lines changed

.ci/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function run_plugin() {
5555
# Create the docs
5656
cd "${SRC_ROOT}"
5757
"${PYTHON}" -m pip install -U -e "${SRC_ROOT}[dev]"
58-
"${PYTHON}" -m dffml service dev install
58+
"${PYTHON}" -m dffml service dev install -user
5959
./scripts/docs.sh
6060

6161
# Log skipped tests to file

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727
performance issue.
2828
- MySQL packaging issue.
2929
- Develop service running one off operations correctly json-loads dict types.
30+
- Operations with configs can be run via the development service.
3031
### Removed
3132
- CLI command `operations` removed in favor of `dataflow run`
3233
- Duplicate dataflow diagram code from development service

dffml/base.py

Lines changed: 98 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def get_args(t):
2929
from .log import LOGGER
3030

3131

32+
ARGP = ArgumentParser()
33+
34+
3235
class ParseExpandAction(argparse.Action):
3336
def __call__(self, parser, namespace, values, option_string=None):
3437
if not isinstance(values, list):
@@ -106,11 +109,92 @@ def __str__(self):
106109
return repr(self)
107110

108111

112+
def mkarg(field):
    """
    Build the ``Arg`` describing how a single dataclass field is parsed.

    The ``Arg`` starts out with the field's annotated type and is refined:

    - The field's default is copied over when the field declares one.
    - ``bool`` fields use the ``store_true`` action.
    - Classes deriving from ``list`` accept one or more values, are collected
      with ``list_action`` and each element is converted via the class's
      ``SINGLETON`` attribute.
    - Any class type exposing a ``load`` attribute is converted via ``load``.
    - ``typing`` list generics accept one or more values of their first type
      argument.
    - ``help`` is taken from the field's metadata when present.

    Raises ``AttributeError`` if a ``list`` subclass has no ``SINGLETON``.
    """
    arg = Arg(type=field.type)
    # Compare against the public sentinel instead of searching the repr of the
    # default for the private class name.
    if field.default is not dataclasses.MISSING:
        arg["default"] = field.default
    if field.type == bool:
        arg["action"] = "store_true"
    elif inspect.isclass(field.type):
        if issubclass(field.type, list):
            arg["nargs"] = "+"
            if not hasattr(field.type, "SINGLETON"):
                raise AttributeError(
                    f"{field.type.__qualname__} missing attribute SINGLETON"
                )
            arg["action"] = list_action(field.type)
            arg["type"] = field.type.SINGLETON
        if hasattr(arg["type"], "load"):
            # TODO (python3.8) Use Protocol
            arg["type"] = arg["type"].load
    elif get_origin(field.type) is list:
        arg["type"] = get_args(field.type)[0]
        arg["nargs"] = "+"
    if "help" in field.metadata:
        arg["help"] = field.metadata["help"]
    return arg
137+
138+
139+
def convert_value(arg, value):
    """
    Convert a raw config value according to the parsing rules in ``arg``.

    When ``value`` is ``None`` the default stored in ``arg`` is returned, or
    ``MissingConfig`` is raised if there is none. Otherwise the value is
    unwrapped (first element only, unless ``arg`` has ``nargs``), passed
    through ``arg["type"]`` and finally through ``arg["action"]`` when those
    keys are present.

    Note that a string ``action`` is resolved to its action class and stored
    back into ``arg``.
    """
    if value is None:
        # Nothing was provided, fall back to the default if there is one
        if "default" not in arg:
            raise MissingConfig
        return arg["default"]

    # TODO This is an oversimplification of argparse's nargs
    multiple = "nargs" in arg
    if not multiple:
        value = value[0]
    if "type" in arg:
        convert = arg["type"]
        if multiple:
            value = [convert(item) for item in value]
        else:
            value = convert(value)
    if "action" in arg:
        if isinstance(arg["action"], str):
            # HACK This accesses _pop_action_class from ArgumentParser, which
            # is prefaced with an underscore, indicating it is not an API we
            # can rely on
            arg["action"] = ARGP._pop_action_class(arg)
        parsed = ConfigurableParsingNamespace()
        handler = arg["action"](dest="dest", option_strings="")
        handler(None, parsed, value)
        value = parsed.dest
    return value
166+
167+
168+
def is_config_dict(value):
    """
    Return True if ``value`` is a dict of the form
    ``{"arg": ..., "config": {...}}``.

    Anything that is not a dict is rejected up front. Without that check
    ``None`` and numbers raise ``TypeError`` on the membership test, strings
    are matched by substring, and strings or lists that pass the membership
    tests raise ``TypeError`` when indexed with ``"config"``.
    """
    return (
        isinstance(value, dict)
        and "arg" in value
        # .get() covers both "key is missing" and "value is not a dict"
        and isinstance(value.get("config"), dict)
    )
174+
175+
176+
def _fromdict(cls, **kwargs):
    """
    Create an instance of the dataclass ``cls`` from plain keyword values.

    Every keyword that matches a field of ``cls`` is run through
    ``convert_value`` using the ``Arg`` built by ``mkarg`` for that field.
    A value shaped like ``{"arg": ..., "config": {...}}`` is split into the
    value to convert and a nested config. If conversion yields a
    ``BaseConfigurable`` subclass, it is replaced by the result of calling its
    ``withconfig`` with that nested config.
    """
    for field in dataclasses.fields(cls):
        if field.name not in kwargs:
            continue
        raw, nested = kwargs[field.name], {}
        if is_config_dict(raw):
            raw, nested = raw["arg"], raw["config"]
        converted = convert_value(mkarg(field), raw)
        if inspect.isclass(converted) and issubclass(
            converted, BaseConfigurable
        ):
            converted = converted.withconfig(
                {field.name: {"arg": None, "config": nested}}
            )
        kwargs[field.name] = converted
    return cls(**kwargs)
190+
191+
109192
def config(cls):
110193
"""
111194
Decorator to create a dataclass
112195
"""
113196
datacls = dataclasses.dataclass(eq=True, frozen=True)(cls)
197+
datacls._fromdict = classmethod(_fromdict)
114198
datacls._replace = lambda self, *args, **kwargs: dataclasses.replace(
115199
self, *args, **kwargs
116200
)
@@ -130,8 +214,6 @@ class BaseConfigurable(abc.ABC):
130214
only parameter to the __init__ of a BaseDataFlowFacilitatorObject.
131215
"""
132216

133-
__argp = ArgumentParser()
134-
135217
def __init__(self, config: BaseConfig) -> None:
136218
"""
137219
BaseConfigurable takes only one argument to __init__,
@@ -225,41 +307,20 @@ def config_get(cls, config, above, *path) -> BaseConfig:
225307
with contextlib.suppress(KeyError):
226308
value = traverse_config_get(config, *no_label_above)
227309

228-
if value is None:
229-
# Return default if not found and available
230-
if "default" in arg:
231-
return arg["default"]
232-
raise MissingConfig(
233-
"%s missing %r from %s"
234-
% (
235-
cls.__qualname__,
236-
label_above[-1],
237-
".".join(label_above[:-1]),
238-
)
310+
try:
311+
return convert_value(arg, value)
312+
except MissingConfig as error:
313+
error.args = (
314+
(
315+
"%s missing %r from %s"
316+
% (
317+
cls.__qualname__,
318+
label_above[-1],
319+
".".join(label_above[:-1]),
320+
)
321+
),
239322
)
240-
241-
if value is None and "default" in arg:
242-
return arg["default"]
243-
# TODO This is a oversimplification of argparse's nargs
244-
if not "nargs" in arg:
245-
value = value[0]
246-
if "type" in arg:
247-
# TODO This is a oversimplification of argparse's nargs
248-
if "nargs" in arg:
249-
value = list(map(arg["type"], value))
250-
else:
251-
value = arg["type"](value)
252-
if "action" in arg:
253-
if isinstance(arg["action"], str):
254-
# HACK This accesses _pop_action_class from ArgumentParser
255-
# which is prefaced with an underscore indicating it not an API
256-
# we can rely on
257-
arg["action"] = cls.__argp._pop_action_class(arg)
258-
namespace = ConfigurableParsingNamespace()
259-
action = arg["action"](dest="dest", option_strings="")
260-
action(None, namespace, value)
261-
value = namespace.dest
262-
return value
323+
raise
263324

264325
@classmethod
265326
def args(cls, args, *above) -> Dict[str, Arg]:
@@ -271,30 +332,7 @@ def args(cls, args, *above) -> Dict[str, Arg]:
271332
f"{cls.__qualname__} requires CONFIG property or implementation of args() classmethod"
272333
)
273334
for field in dataclasses.fields(cls.CONFIG):
274-
arg = Arg(type=field.type)
275-
# HACK For detecting dataclasses._MISSING_TYPE
276-
if "dataclasses._MISSING_TYPE" not in repr(field.default):
277-
arg["default"] = field.default
278-
if field.type == bool:
279-
arg["action"] = "store_true"
280-
elif inspect.isclass(field.type):
281-
if issubclass(field.type, list):
282-
arg["nargs"] = "+"
283-
if not hasattr(field.type, "SINGLETON"):
284-
raise AttributeError(
285-
f"{field.type.__qualname__} missing attribute SINGLETON"
286-
)
287-
arg["action"] = list_action(field.type)
288-
arg["type"] = field.type.SINGLETON
289-
if hasattr(arg["type"], "load"):
290-
# TODO (python3.8) Use Protocol
291-
arg["type"] = arg["type"].load
292-
elif get_origin(field.type) is list:
293-
arg["type"] = get_args(field.type)[0]
294-
arg["nargs"] = "+"
295-
if "help" in field.metadata:
296-
arg["help"] = field.metadata["help"]
297-
cls.config_set(args, above, field.name, arg)
335+
cls.config_set(args, above, field.name, mkarg(field))
298336
return args
299337

300338
@classmethod

dffml/df/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,14 @@ class BaseDataFlowObject(BaseDataFlowFacilitatorObject):
5151

5252
@classmethod
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Return ``args`` untouched unless the class has a ``CONFIG`` attribute,
    in which case the parent implementation fills it in.
    """
    if not hasattr(cls, "CONFIG"):
        return args
    return super(BaseDataFlowObject, cls).args(args, *above)
5557

5658
@classmethod
def config(cls, config, *above) -> BaseConfig:
    """
    Return an empty ``BaseConfig`` unless the class has a ``CONFIG``
    attribute, in which case the parent implementation builds the config.
    """
    if not hasattr(cls, "CONFIG"):
        return BaseConfig()
    return super(BaseDataFlowObject, cls).config(config, *above)
5963

6064

@@ -221,6 +225,9 @@ def __init__(self, config):
221225
config = config_cls(**config)
222226
super().__init__(config)
223227

228+
if config_cls is not None:
229+
Implementation.CONFIG = config_cls
230+
224231
if inspect.isclass(func) and issubclass(
225232
func, OperationImplementationContext
226233
):
@@ -247,9 +254,8 @@ async def run(
247254
if uses_self:
248255
# We can't pass self to functions running in threads
249256
# Its not thread safe!
250-
return await (
251-
func.__get__(self, self.__class__)(**inputs)
252-
)
257+
bound = func.__get__(self, self.__class__)
258+
return await bound(**inputs)
253259
elif inspect.iscoroutinefunction(func):
254260
return await func(**inputs)
255261
else:

dffml/df/memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,10 @@ async def initialize_dataflow(self, dataflow: DataFlow) -> None:
11591159
operation.name,
11601160
)
11611161
opimp_config = BaseConfig()
1162+
if isinstance(opimp_config, dict) and hasattr(
1163+
getattr(opimp, "CONFIG", False), "_fromdict"
1164+
):
1165+
opimp_config = opimp.CONFIG._fromdict(**opimp_config)
11621166
await self.nctx.instantiate(
11631167
operation, opimp_config, opimp=opimp
11641168
)

dffml/operation/model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Dict, Any
2+
3+
from ..repo import Repo
4+
from ..base import config
5+
from ..model import Model
6+
from ..df.types import Definition
7+
from ..df.base import op
8+
9+
10+
@config
class ModelPredictConfig:
    # Config for the model_predict operation.

    # Model returned by the imp_enter lambda of model_predict
    model: Model
    # NOTE(review): not referenced by any code visible here, confirm its use
    msg: str
14+
15+
16+
@op(
    inputs={
        "features": Definition(
            name="repo_features", primitive="Dict[str, Any]"
        )
    },
    outputs={
        "prediction": Definition(
            name="model_predictions", primitive="Dict[str, Any]"
        )
    },
    config_cls=ModelPredictConfig,
    imp_enter={"model": lambda self: self.config.model},
    ctx_enter={"mctx": lambda self: self.parent.model()},
)
async def model_predict(self, features: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run the configured model on a single set of features.

    The features are wrapped in one ``Repo`` and handed to the model
    context's ``predict``. The first repo it yields provides the prediction,
    which is returned under the key named by the model config's ``predict``
    attribute. Nothing is returned if ``predict`` yields no repos.
    """

    async def single_repo():
        yield Repo("", data={"features": features})

    async for predicted in self.mctx.predict(single_repo()):
        # Only the first result matters, exactly one repo was submitted
        target = self.config.model.config.predict
        return {"prediction": {target: predicted.prediction()}}

dffml/service/dev.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,26 @@ async def run_op(self, name, opimp):
209209
error.args = (f"{opimp.op.inputs}: {error.args[0]}",)
210210
raise error
211211

212+
config = {}
213+
extra_config = self.extra_config
214+
215+
for i in range(0, 2):
216+
if "config" in extra_config and len(extra_config["config"]):
217+
extra_config = extra_config["config"]
218+
219+
# TODO(p0) This only goes one level deep. This won't work for
220+
# configs that are multi-leveled
221+
if extra_config:
222+
config = extra_config
223+
224+
dataflow = DataFlow.auto(GetSingle, opimp)
225+
if config:
226+
dataflow.configs[opimp.op.name] = config
227+
212228
# Run the operation in the memory orchestrator
213229
async with MemoryOrchestrator.withconfig({}) as orchestrator:
214230
# Orchestrate the running of these operations
215-
async with orchestrator(DataFlow.auto(GetSingle, opimp)) as octx:
231+
async with orchestrator(dataflow) as octx:
216232
async for ctx, results in octx.run(
217233
[
218234
Input(
@@ -290,7 +306,7 @@ class Install(CMD):
290306
"""
291307

292308
arg_user = Arg(
293-
"-user", "Preform user install", default=True, action="store_false"
309+
"-user", "Preform user install", default=False, action="store_true"
294310
)
295311

296312
async def run(self):
@@ -307,12 +323,7 @@ async def run(self):
307323
)
308324
)
309325
self.logger.info("Installing %r in development mode", packages)
310-
cmd = [
311-
sys.executable,
312-
"-m",
313-
"pip",
314-
"install",
315-
]
326+
cmd = [sys.executable, "-m", "pip", "install"]
316327
if self.user:
317328
# --user sometimes fails
318329
local_path = Path("~", ".local").expanduser().absolute()

tests/integration/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)