Skip to content

Commit d790938

Browse files
committed
New partial output with context - WIP
1 parent 2b5d790 commit d790938

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

mlblocks/mlpipeline.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,11 @@ def _get_block_outputs(self, block_name):
145145
"""Get the list of output variables for the given block."""
146146
block = self.blocks[block_name]
147147
outputs = deepcopy(block.produce_output)
148+
output_names = self.output_names.get(block_name, dict())
148149
for output in outputs:
149-
output['variable'] = '{}.{}'.format(block_name, output['name'])
150+
name = output['name']
151+
context_name = output_names.get(name, name)
152+
output['variable'] = '{}.{}'.format(block_name, context_name)
150153

151154
return outputs
152155

@@ -606,7 +609,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
606609
if y is not None:
607610
context['y'] = y
608611

609-
if output_ is not None:
612+
if isinstance(output_, str):
610613
output_variables = self.get_output_variables(output_)
611614
outputs = output_variables.copy()
612615
output_blocks = {variable.rsplit('.', 1)[0] for variable in output_variables}
@@ -615,6 +618,9 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
615618
outputs = None
616619
output_blocks = set()
617620

621+
if isinstance(output_, int):
622+
output_ = self._get_block_name(output_)
623+
618624
if isinstance(start_, int):
619625
start_ = self._get_block_name(start_)
620626

@@ -628,16 +634,19 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
628634

629635
self._fit_block(block, block_name, context)
630636

631-
if (block_name != self._last_block_name) or (block_name in output_blocks):
637+
last_block = block_name != self._last_block_name
638+
if last_block or (block_name == output_) or (block_name in output_blocks):
632639
self._produce_block(block, block_name, context, output_variables, outputs)
633640

634641
# We already captured the output from this block
635642
if block_name in output_blocks:
636643
output_blocks.remove(block_name)
644+
elif block_name == output_:
645+
return context
637646

638647
# If there was an output_ but there are no pending
639648
# outputs we are done.
640-
if output_ is not None and not output_blocks:
649+
if output_variables is not None and not output_blocks:
641650
if len(outputs) > 1:
642651
return tuple(outputs)
643652
else:

0 commit comments

Comments
 (0)