Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 03ea6d5

Browse files
authored
df: memory: Cleanup and move methods
- Move the `operations_parameter_set_pairs` and `validator_target_set_pairs` methods to `MemoryOrchestratorContext`
- Make removing inputs from input sets async instead of non-async

Fixes: #477
Fixes: #452
1 parent edcf1f8 commit 03ea6d5

File tree

5 files changed

+93
-97
lines changed

5 files changed

+93
-97
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
- New model tutorial mentions file paths that should be edited.
2020
- DataFlow is no longer a dataclass to prevent it from being exported
2121
incorrectly.
22+
- `operations_parameter_set_pairs` moved to `MemoryOrchestratorContext`
2223

2324
## [0.3.5] - 2020-03-10
2425
### Added

dffml/df/base.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,19 @@ async def _asdict(self) -> Dict[str, Any]:
482482
item.definition.name: item.value async for item in self.inputs()
483483
}
484484

485+
@abc.abstractmethod
486+
async def remove_input(self, item: Input) -> None:
487+
"""
488+
Removes item from input set
489+
"""
490+
pass
491+
492+
@abc.abstractmethod
493+
async def remove_unvalidated_inputs(self) -> "BaseInputSet":
494+
"""
495+
Removes `unvalidated` inputs from internal list and returns the same.
496+
"""
497+
485498

486499
class BaseParameterSetConfig(NamedTuple):
487500
ctx: BaseInputSetContext
@@ -737,24 +750,6 @@ async def dispatch(
737750
Schedule the running of an operation
738751
"""
739752

740-
@abc.abstractmethod
741-
async def operations_parameter_set_pairs(
742-
self,
743-
ictx: BaseInputNetworkContext,
744-
octx: BaseOperationNetworkContext,
745-
rctx: BaseRedundancyCheckerContext,
746-
ctx: BaseInputSetContext,
747-
*,
748-
new_input_set: BaseInputSet = None,
749-
stage: Stage = Stage.PROCESSING,
750-
) -> AsyncIterator[Tuple[Operation, BaseParameterSet]]:
751-
"""
752-
Use new_input_set to determine which operations in the network might be
753-
up for running. Cross check using existing inputs to generate per
754-
input set context novel input pairings. Yield novel input pairings
755-
along with their operations as they are generated.
756-
"""
757-
758753

759754
# TODO We should be able to specify multiple operation implementation networks.
760755
# This would enable operations to live in different place, accessed via the
@@ -790,6 +785,21 @@ async def run_operations(
790785
Run all the operations then run cleanup and output operations
791786
"""
792787

788+
@abc.abstractmethod
789+
async def operations_parameter_set_pairs(
790+
self,
791+
ctx: BaseInputSetContext,
792+
*,
793+
new_input_set: BaseInputSet = None,
794+
stage: Stage = Stage.PROCESSING,
795+
) -> AsyncIterator[Tuple[Operation, BaseParameterSet]]:
796+
"""
797+
Use new_input_set to determine which operations in the network might be
798+
up for running. Cross check using existing inputs to generate per
799+
input set context novel input pairings. Yield novel input pairings
800+
along with their operations as they are generated.
801+
"""
802+
793803

794804
@base_entry_point("dffml.orchestrator", "orchestrator")
795805
class BaseOrchestrator(BaseDataFlowObject):

dffml/df/memory.py

Lines changed: 61 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ async def inputs(self) -> AsyncIterator[Input]:
127127
for item in self.__inputs:
128128
yield item
129129

130-
def remove_input(self, item: Input):
130+
async def remove_input(self, item: Input):
131131
for x in self.__inputs[:]:
132132
if x.uid == item.uid:
133133
self.__inputs.remove(x)
134134
break
135135

136-
def remove_unvalidated_inputs(self) -> "MemoryInputSet":
136+
async def remove_unvalidated_inputs(self) -> "MemoryInputSet":
137137
"""
138138
Removes `unvalidated` inputs from internal list and returns the same.
139139
"""
@@ -278,7 +278,7 @@ async def add(self, input_set: BaseInputSet):
278278
# self.ctxhd
279279

280280
# remove unvalidated inputs
281-
unvalidated_input_set = input_set.remove_unvalidated_inputs()
281+
unvalidated_input_set = await input_set.remove_unvalidated_inputs()
282282

283283
# If the context for this input set does not exist create a
284284
# NotificationSet for it to notify the orchestrator
@@ -1026,65 +1026,6 @@ async def dispatch(
10261026
task.add_done_callback(ignore_args(self.completed_event.set))
10271027
return task
10281028

1029-
async def operations_parameter_set_pairs(
1030-
self,
1031-
ictx: BaseInputNetworkContext,
1032-
octx: BaseOperationNetworkContext,
1033-
rctx: BaseRedundancyCheckerContext,
1034-
ctx: BaseInputSetContext,
1035-
dataflow: DataFlow,
1036-
*,
1037-
new_input_set: Optional[BaseInputSet] = None,
1038-
stage: Stage = Stage.PROCESSING,
1039-
) -> AsyncIterator[Tuple[Operation, BaseInputSet]]:
1040-
"""
1041-
Use new_input_set to determine which operations in the network might be
1042-
up for running. Cross check using existing inputs to generate per
1043-
input set context novel input pairings. Yield novel input pairings
1044-
along with their operations as they are generated.
1045-
"""
1046-
# Get operations which may possibly run as a result of these new inputs
1047-
async for operation in octx.operations(
1048-
dataflow, input_set=new_input_set, stage=stage
1049-
):
1050-
# Generate all pairs of un-run input combinations
1051-
async for parameter_set in ictx.gather_inputs(
1052-
rctx, operation, dataflow, ctx=ctx
1053-
):
1054-
yield operation, parameter_set
1055-
1056-
async def validator_target_set_pairs(
1057-
self,
1058-
octx: BaseOperationNetworkContext,
1059-
rctx: BaseRedundancyCheckerContext,
1060-
ctx: BaseInputSetContext,
1061-
dataflow: DataFlow,
1062-
unvalidated_input_set: BaseInputSet,
1063-
):
1064-
async for unvalidated_input in unvalidated_input_set.inputs():
1065-
validator_instance_name = unvalidated_input.definition.validate
1066-
validator = dataflow.validators.get(validator_instance_name, None)
1067-
if validator is None:
1068-
raise ValidatorMissing(
1069-
"Validator with instance_name {validator_instance_name} not found"
1070-
)
1071-
# There is only one `input` in `validators`
1072-
input_name, input_definition = list(validator.inputs.items())[0]
1073-
parameter = Parameter(
1074-
key=input_name,
1075-
value=unvalidated_input.value,
1076-
origin=unvalidated_input,
1077-
definition=input_definition,
1078-
)
1079-
parameter_set = MemoryParameterSet(
1080-
MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
1081-
)
1082-
async for parameter_set, exists in rctx.exists(
1083-
validator, parameter_set
1084-
):
1085-
if not exists:
1086-
yield validator, parameter_set
1087-
10881029

10891030
@entrypoint("memory")
10901031
class MemoryOperationImplementationNetwork(
@@ -1408,6 +1349,60 @@ async def run(
14081349
else:
14091350
task.exception()
14101351

1352+
async def operations_parameter_set_pairs(
1353+
self,
1354+
ctx: BaseInputSetContext,
1355+
dataflow: DataFlow,
1356+
*,
1357+
new_input_set: Optional[BaseInputSet] = None,
1358+
stage: Stage = Stage.PROCESSING,
1359+
) -> AsyncIterator[Tuple[Operation, BaseInputSet]]:
1360+
"""
1361+
Use new_input_set to determine which operations in the network might be
1362+
up for running. Cross check using existing inputs to generate per
1363+
input set context novel input pairings. Yield novel input pairings
1364+
along with their operations as they are generated.
1365+
"""
1366+
# Get operations which may possibly run as a result of these new inputs
1367+
async for operation in self.octx.operations(
1368+
dataflow, input_set=new_input_set, stage=stage
1369+
):
1370+
# Generate all pairs of un-run input combinations
1371+
async for parameter_set in self.ictx.gather_inputs(
1372+
self.rctx, operation, dataflow, ctx=ctx
1373+
):
1374+
yield operation, parameter_set
1375+
1376+
async def validator_target_set_pairs(
1377+
self,
1378+
ctx: BaseInputSetContext,
1379+
dataflow: DataFlow,
1380+
unvalidated_input_set: BaseInputSet,
1381+
):
1382+
async for unvalidated_input in unvalidated_input_set.inputs():
1383+
validator_instance_name = unvalidated_input.definition.validate
1384+
validator = dataflow.validators.get(validator_instance_name, None)
1385+
if validator is None:
1386+
raise ValidatorMissing(
1387+
"Validator with instance_name {validator_instance_name} not found"
1388+
)
1389+
# There is only one `input` in `validators`
1390+
input_name, input_definition = list(validator.inputs.items())[0]
1391+
parameter = Parameter(
1392+
key=input_name,
1393+
value=unvalidated_input.value,
1394+
origin=unvalidated_input,
1395+
definition=input_definition,
1396+
)
1397+
parameter_set = MemoryParameterSet(
1398+
MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
1399+
)
1400+
async for parameter_set, exists in self.rctx.exists(
1401+
validator, parameter_set
1402+
):
1403+
if not exists:
1404+
yield validator, parameter_set
1405+
14111406
async def run_operations_for_ctx(
14121407
self, ctx: BaseContextHandle, *, strict: bool = True
14131408
) -> AsyncIterator[Tuple[BaseContextHandle, Dict[str, Any]]]:
@@ -1468,9 +1463,7 @@ async def run_operations_for_ctx(
14681463
unvalidated_input_set,
14691464
new_input_set,
14701465
) in new_input_sets:
1471-
async for operation, parameter_set in self.nctx.validator_target_set_pairs(
1472-
self.octx,
1473-
self.rctx,
1466+
async for operation, parameter_set in self.validator_target_set_pairs(
14741467
ctx,
14751468
self.config.dataflow,
14761469
unvalidated_input_set,
@@ -1497,10 +1490,7 @@ async def run_operations_for_ctx(
14971490
)
14981491
# Identify which operations have completed contextually
14991492
# appropriate input sets which haven't been run yet
1500-
async for operation, parameter_set in self.nctx.operations_parameter_set_pairs(
1501-
self.ictx,
1502-
self.octx,
1503-
self.rctx,
1493+
async for operation, parameter_set in self.operations_parameter_set_pairs(
15041494
ctx,
15051495
self.config.dataflow,
15061496
new_input_set=new_input_set,
@@ -1568,13 +1558,8 @@ async def run_operations_for_ctx(
15681558
async def run_stage(self, ctx: BaseInputSetContext, stage: Stage):
15691559
# Identify which operations have complete contextually appropriate
15701560
# input sets which haven't been run yet and are stage operations
1571-
async for operation, parameter_set in self.nctx.operations_parameter_set_pairs(
1572-
self.ictx,
1573-
self.octx,
1574-
self.rctx,
1575-
ctx,
1576-
self.config.dataflow,
1577-
stage=stage,
1561+
async for operation, parameter_set in self.operations_parameter_set_pairs(
1562+
ctx, self.config.dataflow, stage=stage
15781563
):
15791564
# Add inputs and operation to redundancy checker before dispatch
15801565
await self.rctx.add(operation, parameter_set)

dffml/df/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def export(self, *, linked: bool = False):
556556
"operations": {
557557
instance_name: operation.export()
558558
for instance_name, operation in self.operations.items()
559-
},
559+
}
560560
}
561561
if self.seed:
562562
exported["seed"] = self.seed.copy()

model/tensorflow/examples/tfdnnr/tfdnnr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
model = DNNRegressionModel(
66
features=Features(
7-
DefFeature("Feature1", float, 1), DefFeature("Feature2", float, 1),
7+
DefFeature("Feature1", float, 1), DefFeature("Feature2", float, 1)
88
),
99
predict=DefFeature("TARGET", float, 1),
1010
epochs=300,
@@ -20,7 +20,7 @@
2020

2121
# Make prediction
2222
for i, features, prediction in predict(
23-
model, {"Feature1": 0.21, "Feature2": 0.18, "TARGET": 0.84},
23+
model, {"Feature1": 0.21, "Feature2": 0.18, "TARGET": 0.84}
2424
):
2525
features["TARGET"] = prediction["TARGET"]["value"]
2626
print(features)

0 commit comments

Comments (0)