Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit d57f427

Browse files
committed
df: input flow: Add get_alternate_definitions helper function
We had a lack of clarity around what was going on with alternate definitions in commit 8657032. There were several if statements where we are looping through the origins and if the origin is a dict, we check if the value is a tuple or a list. If it is not, the key is the operation instance name and the value is the output name for that operation instance. If it is, the key is the origin and the value is the list or tuple of definition names which are acceptable alternate definitions when the input is coming from that origin. This patch adds a helper method to make that more clear. Signed-off-by: John Andersen <[email protected]>
1 parent 9f286de commit d57f427

File tree

2 files changed

+49
-16
lines changed

2 files changed

+49
-16
lines changed

dffml/df/memory.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -490,12 +490,10 @@ async def gather_inputs(
490490
# Ensure all conditions from all origins are True
491491
for origin in origins:
492492
# See comment in input_flow.inputs section
493-
alternate_definitions = []
494-
if isinstance(origin, tuple) and isinstance(
495-
origin[1], (list, tuple)
496-
):
497-
alternate_definitions = origin[1]
498-
origin = origin[0]
493+
(
494+
alternate_definitions,
495+
origin,
496+
) = input_flow.get_alternate_definitions(origin)
499497
# Bail if the condition doesn't exist
500498
if not origin in by_origin:
501499
return
@@ -547,12 +545,10 @@ async def gather_inputs(
547545
# These definitions will be used instead of the
548546
# default one the input specified for the
549547
# operation).
550-
alternate_definitions = []
551-
if isinstance(origin, tuple) and isinstance(
552-
origin[1], (list, tuple)
553-
):
554-
alternate_definitions = origin[1]
555-
origin = origin[0]
548+
(
549+
alternate_definitions,
550+
origin,
551+
) = input_flow.get_alternate_definitions(origin)
556552
# Don't try to grab inputs from an origin that
557553
# doesn't have any to give us
558554
if not origin in by_origin:

dffml/df/types.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,41 @@ def export(self):
429429
def _fromdict(cls, **kwargs):
430430
return cls(**kwargs)
431431

432+
@staticmethod
433+
def get_alternate_definitions(
434+
origin: Tuple[Union[List[str], Tuple[str]], str]
435+
) -> Tuple[Union[List[str], Tuple[str]], str]:
436+
"""
437+
Returns the alternate definitions and origin for an entry within an input
438+
flow. If there are no alternate defintions then the first element of the
439+
returned tuple is an empty list.
440+
441+
Examples
442+
--------
443+
444+
>>> from dffml import InputFlow
445+
>>>
446+
>>> input_flow = InputFlow(
447+
... inputs={
448+
... "features": [
449+
... {"seed": ["Years", "Expertise", "Trust", "Salary"]}
450+
... ],
451+
... "token": [
452+
... "client",
453+
... ]
454+
... }
455+
... )
456+
>>>
457+
>>> input_flow.get_alternate_definitions(list(input_flow.inputs["features"][0].items())[0])
458+
(['Years', 'Expertise', 'Trust', 'Salary'], 'seed')
459+
>>>
460+
input_flow.get_alternate_definitions(list(input_flow.inputs["other"][0].items())[0])
461+
([], 'client')
462+
"""
463+
if isinstance(origin, tuple) and isinstance(origin[1], (list, tuple)):
464+
return origin[1], origin[0]
465+
return [], origin
466+
432467

433468
@dataclass
434469
class Forward:
@@ -600,8 +635,9 @@ def update_by_origin(self):
600635
)
601636
else:
602637
for origin in output_source.items():
603-
if isinstance(origin[1], (list, tuple,)):
604-
origin = origin[0]
638+
_, origin = input_flow.get_alternate_definitions(
639+
origin
640+
)
605641
self.by_origin[operation.stage].setdefault(origin, [])
606642
self.by_origin[operation.stage][origin].append(
607643
operation
@@ -623,8 +659,9 @@ def update_by_origin(self):
623659
# the Input.origin (like "seed"). And the value
624660
# (origin[1]) is the list of definitions which are
625661
# acceptable from that origin for this input.
626-
if isinstance(origin[1], (list, tuple,)):
627-
origin = origin[0]
662+
_, origin = input_flow.get_alternate_definitions(
663+
origin
664+
)
628665
self.by_origin[operation.stage].setdefault(
629666
origin, []
630667
)

0 commit comments

Comments
 (0)