Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit b9e9758

Browse files
authored
df: memory: Input validation using operations
Signed-off-by: John Andersen <[email protected]>
1 parent 755a000 commit b9e9758

File tree

8 files changed

+175
-19
lines changed

8 files changed

+175
-19
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88
### Added
99
- Docstrings and doctestable examples to `record.py`.
10+
- Inputs can be validated using operations
11+
- `validate` parameter in `Definition` takes an `Operation.instance_name`
1012
### Fixed
1113
- New model tutorial mentions file paths that should be edited.
1214

dffml/df/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ class NotOpImp(Exception):
1616

1717
class InputValidationError(Exception):
    """
    Error type for input validation failures.

    NOTE(review): the sites that raise this are outside the visible diff;
    confirm the exact conditions against the callers.
    """
19+
20+
21+
class ValidatorMissing(Exception):
    """
    Raised when a definition names a validator operation (by instance name)
    that is not present among the dataflow's validators.
    """

dffml/df/memory.py

Lines changed: 103 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
Set,
2121
)
2222

23-
from .exceptions import ContextNotPresent, DefinitionNotInContext
23+
from .exceptions import (
24+
ContextNotPresent,
25+
DefinitionNotInContext,
26+
ValidatorMissing,
27+
)
2428
from .types import Input, Parameter, Definition, Operation, Stage, DataFlow
2529
from .base import (
2630
OperationException,
@@ -122,6 +126,26 @@ async def inputs(self) -> AsyncIterator[Input]:
122126
for item in self.__inputs:
123127
yield item
124128

129+
def remove_input(self, item: Input):
    """
    Remove the first stored input whose ``uid`` matches ``item.uid``.

    The internal list is left untouched when no stored input has a
    matching ``uid``.
    """
    for index, existing in enumerate(self.__inputs):
        if existing.uid == item.uid:
            # Delete by position: no need to copy the list first, and no
            # second equality-based scan as ``list.remove`` would perform.
            # Safe while iterating because we stop immediately afterwards.
            del self.__inputs[index]
            break
134+
135+
def remove_unvalidated_inputs(self) -> "MemoryInputSet":
    """
    Remove every input whose ``validated`` flag is falsy from this set.

    Returns a new ``MemoryInputSet`` holding the removed inputs, created
    with the same ``ctx`` as this set. Validated inputs stay in this set in
    their original order.
    """
    validated_inputs = []
    unvalidated_inputs = []
    # Single pass partition instead of calling ``list.remove`` per item,
    # which rescans the list each time (quadratic) and matches by equality
    for candidate in self.__inputs:
        if candidate.validated:
            validated_inputs.append(candidate)
        else:
            unvalidated_inputs.append(candidate)
    # Slice assignment mutates the existing list object in place
    self.__inputs[:] = validated_inputs
    return MemoryInputSet(
        MemoryInputSetConfig(ctx=self.ctx, inputs=unvalidated_inputs)
    )
148+
125149

126150
class MemoryParameterSetConfig(NamedTuple):
127151
ctx: BaseInputSetContext
@@ -249,15 +273,19 @@ async def add(self, input_set: BaseInputSet):
249273
handle_string = handle.as_string()
250274
# TODO These ctx.add calls should probably happen after inputs are in
251275
# self.ctxhd
276+
277+
# remove unvalidated inputs
278+
unvalidated_input_set = input_set.remove_unvalidated_inputs()
279+
252280
# If the context for this input set does not exist create a
253281
# NotificationSet for it to notify the orchestrator
254282
if not handle_string in self.input_notification_set:
255283
self.input_notification_set[handle_string] = NotificationSet()
256284
async with self.ctx_notification_set() as ctx:
257-
await ctx.add(input_set.ctx)
285+
await ctx.add((None, input_set.ctx))
258286
# Add the input set to the incoming inputs
259287
async with self.input_notification_set[handle_string]() as ctx:
260-
await ctx.add(input_set)
288+
await ctx.add((unvalidated_input_set, input_set))
261289
# Associate inputs with their context handle grouped by definition
262290
async with self.ctxhd_lock:
263291
# Create dict for handle_string if not present
@@ -921,6 +949,7 @@ async def run_dispatch(
921949
octx: BaseOrchestratorContext,
922950
operation: Operation,
923951
parameter_set: BaseParameterSet,
952+
set_valid: bool = True,
924953
):
925954
"""
926955
Run an operation in the background and add its outputs to the input
@@ -952,14 +981,14 @@ async def run_dispatch(
952981
if not key in expand:
953982
output = [output]
954983
for value in output:
955-
inputs.append(
956-
Input(
957-
value=value,
958-
definition=operation.outputs[key],
959-
parents=parents,
960-
origin=(operation.instance_name, key),
961-
)
984+
new_input = Input(
985+
value=value,
986+
definition=operation.outputs[key],
987+
parents=parents,
988+
origin=(operation.instance_name, key),
962989
)
990+
new_input.validated = set_valid
991+
inputs.append(new_input)
963992
except KeyError as error:
964993
raise KeyError(
965994
"Value %s missing from output:definition mapping %s(%s)"
@@ -1020,6 +1049,38 @@ async def operations_parameter_set_pairs(
10201049
):
10211050
yield operation, parameter_set
10221051

1052+
async def validator_target_set_pairs(
    self,
    octx: BaseOperationNetworkContext,
    rctx: BaseRedundancyCheckerContext,
    ctx: BaseInputSetContext,
    dataflow: DataFlow,
    unvalidated_input_set: BaseInputSet,
):
    """
    Pair each unvalidated input with the validator operation it names.

    For every input in ``unvalidated_input_set`` the validator is looked up
    in ``dataflow.validators`` using ``input.definition.validate`` as the
    instance name. The input's value is wrapped in a single-parameter
    ``MemoryParameterSet`` and offered to the redundancy checker; only
    pairs the checker reports as not yet existing are yielded.

    Yields ``(validator, parameter_set)`` tuples.

    Raises ``ValidatorMissing`` when no validator is registered under the
    instance name the input's definition refers to.
    """
    async for unvalidated_input in unvalidated_input_set.inputs():
        validator_instance_name = unvalidated_input.definition.validate
        validator = dataflow.validators.get(validator_instance_name, None)
        if validator is None:
            # f-string so the offending instance name is actually reported
            raise ValidatorMissing(
                f"Validator with instance_name {validator_instance_name} not found"
            )
        # There is only one `input` in `validators`
        input_name, input_definition = list(validator.inputs.items())[0]
        parameter = Parameter(
            key=input_name,
            value=unvalidated_input.value,
            origin=unvalidated_input,
            definition=input_definition,
        )
        parameter_set = MemoryParameterSet(
            MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
        )
        # Yield the parameter set handed back by the redundancy checker,
        # under its own name so it does not shadow the one built above
        async for checked_parameter_set, exists in rctx.exists(
            validator, parameter_set
        ):
            if not exists:
                yield validator, checked_parameter_set
1083+
10231084

10241085
@entrypoint("memory")
10251086
class MemoryOperationImplementationNetwork(
@@ -1382,17 +1443,44 @@ async def run_operations_for_ctx(
13821443
task.print_stack(file=output)
13831444
self.logger.error("%s", output.getvalue().rstrip())
13841445
output.close()
1446+
13851447
elif task is input_set_enters_network:
13861448
(
13871449
more,
13881450
new_input_sets,
13891451
) = input_set_enters_network.result()
1390-
for new_input_set in new_input_sets:
1452+
for (
1453+
unvalidated_input_set,
1454+
new_input_set,
1455+
) in new_input_sets:
1456+
async for operation, parameter_set in self.nctx.validator_target_set_pairs(
1457+
self.octx,
1458+
self.rctx,
1459+
ctx,
1460+
self.config.dataflow,
1461+
unvalidated_input_set,
1462+
):
1463+
await self.rctx.add(
1464+
operation, parameter_set
1465+
) # is this required here?
1466+
dispatch_operation = await self.nctx.dispatch(
1467+
self, operation, parameter_set
1468+
)
1469+
dispatch_operation.operation = operation
1470+
dispatch_operation.parameter_set = (
1471+
parameter_set
1472+
)
1473+
tasks.add(dispatch_operation)
1474+
self.logger.debug(
1475+
"[%s]: dispatch operation: %s",
1476+
ctx_str,
1477+
operation.instance_name,
1478+
)
13911479
# forward inputs to subflow
13921480
await self.forward_inputs_to_subflow(
13931481
[x async for x in new_input_set.inputs()]
13941482
)
1395-
# Identify which operations have complete contextually
1483+
# Identify which operations have complete, contextually
13961484
# appropriate input sets which haven't been run yet
13971485
async for operation, parameter_set in self.nctx.operations_parameter_set_pairs(
13981486
self.ictx,
@@ -1402,6 +1490,9 @@ async def run_operations_for_ctx(
14021490
self.config.dataflow,
14031491
new_input_set=new_input_set,
14041492
):
1493+
# Validation operations shouldn't be run here
1494+
if operation.validator:
1495+
continue
14051496
# Add inputs and operation to redundancy checker before
14061497
# dispatch
14071498
await self.rctx.add(operation, parameter_set)

dffml/df/types.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class Operation(NamedTuple, Entrypoint):
122122
conditions: Optional[List[Definition]] = []
123123
expand: Optional[List[str]] = []
124124
instance_name: Optional[str] = None
125+
validator: bool = False
125126

126127
def export(self):
127128
exported = {
@@ -270,11 +271,13 @@ def __init__(
270271
definition: Definition,
271272
parents: Optional[List["Input"]] = None,
272273
origin: Optional[Union[str, Tuple[Operation, str]]] = "seed",
274+
validated: bool = True,
273275
*,
274276
uid: Optional[str] = "",
275277
):
276278
# TODO Add optional parameter Input.target which specifies the operation
277279
# instance name this Input is intended for.
280+
self.validated = validated
278281
if parents is None:
279282
parents = []
280283
if definition.spec is not None:
@@ -288,7 +291,11 @@ def __init__(
288291
elif isinstance(value, dict):
289292
value = definition.spec(**value)
290293
if definition.validate is not None:
291-
value = definition.validate(value)
294+
if callable(definition.validate):
295+
value = definition.validate(value)
296+
# if validate is a string (operation.instance_name) set `not validated`
297+
elif isinstance(definition.validate, str):
298+
self.validated = False
292299
self.value = value
293300
self.definition = definition
294301
self.parents = parents
@@ -424,6 +431,8 @@ def __post_init__(self):
424431
self.by_origin = {}
425432
if self.implementations is None:
426433
self.implementations = {}
434+
self.validators = {} # Maps `validator` ops instance_name to op
435+
427436
# Allow callers to pass in functions decorated with op. Iterate over the
428437
# given operations and replace any which have been decorated with their
429438
# operation. Add the implementation to our dict of implementations.
@@ -451,9 +460,10 @@ def __post_init__(self):
451460
self.operations[instance_name] = operation
452461
value = operation
453462
# Make sure every operation has the correct instance name
454-
self.operations[instance_name] = value._replace(
455-
instance_name=instance_name
456-
)
463+
value = value._replace(instance_name=instance_name)
464+
self.operations[instance_name] = value
465+
if value.validator:
466+
self.validators[instance_name] = value
457467
# Grab all definitions from operations
458468
operations = list(self.operations.values())
459469
definitions = list(

examples/shouldi/tests/test_npm_audit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class TestRunNPM_AuditOp(AsyncTestCase):
2121
"36b3ce51780ee6ea8dcec266c9d09e3a00198868ba1b041569950b82cf45884da0c47ec354dd8514022169849dfe8b7c",
2222
)
2323
async def test_run(self, npm_audit, javascript_algo):
24-
with prepend_to_path(npm_audit / "bin",):
24+
with prepend_to_path(npm_audit / "bin"):
2525
results = await run_npm_audit(
2626
str(
2727
javascript_algo

model/scikit/dffml_model_scikit/scikit_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def applicable_features(self, features):
226226
field(
227227
"Directory where state should be saved",
228228
default=pathlib.Path(
229-
"~", ".cache", "dffml", f"scikit-{entry_point_name}",
229+
"~", ".cache", "dffml", f"scikit-{entry_point_name}"
230230
),
231231
),
232232
),

scripts/docs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def gen_docs(
212212

213213
def fake_getpwuid(uid):
    """
    Build a fixed ``pwd.struct_passwd`` entry for the user ``"user"``.

    ``uid`` is used for both the user id and the group id fields; every
    other field is a constant.
    """
    passwd_fields = ("user", "x", uid, uid, "", "/home/user", "/bin/bash")
    return pwd.struct_passwd(passwd_fields)
217217

218218

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def pie_validation(x):
1919
ShapeName = Definition(
2020
name="shape_name", primitive="str", validate=lambda x: x.upper()
2121
)
22+
SHOUTIN = Definition(
23+
name="shout_in", primitive="str", validate="validate_shout_instance"
24+
)
25+
SHOUTOUT = Definition(name="shout_out", primitive="str")
2226

2327

2428
@op(
@@ -35,6 +39,20 @@ async def get_circle(name: str, radius: float, pie: float):
3539
}
3640

3741

42+
@op(
    inputs={"shout_in": SHOUTIN},
    outputs={"shout_in_validated": SHOUTIN},
    validator=True,
)
def validate_shouts(shout_in):
    """
    Validator operation for ``SHOUTIN`` values.

    Appends ``"_validated"`` to ``shout_in`` and returns the result under
    the ``shout_in_validated`` output key.
    """
    validated_shout = shout_in + "_validated"
    return {"shout_in_validated": validated_shout}
49+
50+
51+
@op(inputs={"shout_in": SHOUTIN}, outputs={"shout_out": SHOUTOUT})
def echo_shout(shout_in):
    """
    Pass ``shout_in`` through unchanged under the ``shout_out`` output key.
    """
    echoed = {"shout_out": shout_in}
    return echoed
54+
55+
3856
class TestDefintion(AsyncTestCase):
3957
async def setUp(self):
4058
self.dataflow = DataFlow(
@@ -80,3 +98,34 @@ async def test_validation_error(self):
8098
]
8199
}
82100
pass
101+
102+
async def test_vaildation_by_op(self):
    """
    Run a dataflow containing the ``validate_shout_instance`` validator
    operation and ``echo_shout``, feed it a ``SHOUTIN`` input, and check
    that the ``shout_out`` result carries the suffix added by the
    validator.
    """
    # Ask GetSingle to report the value of echo_shout's output definition
    get_single_spec = Input(
        value=[echo_shout.op.outputs["shout_out"].name],
        definition=GetSingle.op.inputs["spec"],
    )
    operations = {
        "validate_shout_instance": validate_shouts.op,
        "echo_shout": echo_shout.op,
        "get_single": GetSingle.imp.op,
    }
    implementations = {
        validate_shouts.op.name: validate_shouts.imp,
        echo_shout.op.name: echo_shout.imp,
    }
    test_dataflow = DataFlow(
        operations=operations,
        seed=[get_single_spec],
        implementations=implementations,
    )
    # One context whose only input is a SHOUTIN value
    shout_input = Input(value="validation_status:", definition=SHOUTIN)
    test_inputs = {"TestShoutOut": [shout_input]}
    async with MemoryOrchestrator.withconfig(
        {}
    ) as orchestrator, orchestrator(test_dataflow) as octx:
        async for _ctx_str, results in octx.run(test_inputs):
            self.assertIn("shout_out", results)
            self.assertEqual(
                results["shout_out"], "validation_status:_validated"
            )

0 commit comments

Comments
 (0)