Skip to content

Commit 1a0eb09

Browse files
committed
Allow getting full context in partial outputs
1 parent d790938 commit 1a0eb09

File tree

3 files changed

+171
-72
lines changed

3 files changed

+171
-72
lines changed

mlblocks/mlpipeline.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import logging
7+
import re
78
from collections import Counter, OrderedDict, defaultdict
89
from copy import deepcopy
910

@@ -198,12 +199,15 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,
198199
if hyperparameters:
199200
self.set_hyperparameters(hyperparameters)
200201

202+
self._re_block_name = re.compile(r'(^[^#]+#\d+)(\..*)?')
203+
201204
def _get_str_output(self, output):
202205
"""Get the outputs that correspond to the str specification."""
203206
if output in self.outputs:
204207
return self.outputs[output]
205208
elif output in self.blocks:
206-
return self._get_block_outputs(output)
209+
return [{'name': output, 'variable': output}]
210+
# return self._get_block_outputs(output)
207211
elif '.' in output:
208212
block_name, variable_name = output.rsplit('.', 1)
209213
block = self.blocks.get(block_name)
@@ -260,11 +264,11 @@ def get_outputs(self, outputs='default'):
260264

261265
computed = list()
262266
for output in outputs:
267+
if isinstance(output, int):
268+
output = self._get_block_name(output)
269+
263270
if isinstance(output, str):
264271
computed.extend(self._get_str_output(output))
265-
elif isinstance(output, int):
266-
block_name = self._get_block_name(output)
267-
computed.extend(self._get_block_outputs(block_name))
268272
else:
269273
raise TypeError('Output Specification can only be str or int')
270274

@@ -316,6 +320,18 @@ def get_output_variables(self, outputs='default'):
316320
outputs = self.get_outputs(outputs)
317321
return [output['variable'] for output in outputs]
318322

323+
def _extract_block_name(self, variable_name):
324+
return self._re_block_name.search(variable_name).group(1)
325+
326+
def _prepare_outputs(self, outputs):
327+
output_variables = self.get_output_variables(outputs)
328+
outputs = output_variables.copy()
329+
output_blocks = {
330+
self._extract_block_name(variable)
331+
for variable in output_variables
332+
}
333+
return output_variables, outputs, output_blocks
334+
319335
@staticmethod
320336
def _flatten_dict(hyperparameters):
321337
return {
@@ -519,13 +535,11 @@ def _extract_outputs(self, block_name, outputs, block_outputs):
519535

520536
return output_dict
521537

522-
def _update_outputs(self, block_name, output_variables, outputs, outputs_dict):
538+
def _update_outputs(self, variable_name, output_variables, outputs, value):
523539
"""Set the requested block outputs into the outputs list in the right place."""
524-
for key, value in outputs_dict.items():
525-
variable_name = '{}.{}'.format(block_name, key)
526-
if variable_name in output_variables:
527-
index = output_variables.index(variable_name)
528-
outputs[index] = deepcopy(value)
540+
if variable_name in output_variables:
541+
index = output_variables.index(variable_name)
542+
outputs[index] = deepcopy(value)
529543

530544
def _fit_block(self, block, block_name, context):
531545
"""Get the block args from the context and fit the block."""
@@ -554,7 +568,12 @@ def _produce_block(self, block, block_name, context, output_variables, outputs):
554568
context.update(outputs_dict)
555569

556570
if output_variables:
557-
self._update_outputs(block_name, output_variables, outputs, outputs_dict)
571+
if block_name in output_variables:
572+
self._update_outputs(block_name, output_variables, outputs, context)
573+
else:
574+
for key, value in outputs_dict.items():
575+
variable_name = '{}.{}'.format(block_name, key)
576+
self._update_outputs(variable_name, output_variables, outputs, value)
558577

559578
except Exception:
560579
if self.verbose:
@@ -609,17 +628,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
609628
if y is not None:
610629
context['y'] = y
611630

612-
if isinstance(output_, str):
613-
output_variables = self.get_output_variables(output_)
614-
outputs = output_variables.copy()
615-
output_blocks = {variable.rsplit('.', 1)[0] for variable in output_variables}
616-
else:
631+
if output_ is None:
617632
output_variables = None
618633
outputs = None
619634
output_blocks = set()
620-
621-
if isinstance(output_, int):
622-
output_ = self._get_block_name(output_)
635+
else:
636+
output_variables, outputs, output_blocks = self._prepare_outputs(output_)
623637

624638
if isinstance(start_, int):
625639
start_ = self._get_block_name(start_)
@@ -634,15 +648,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
634648

635649
self._fit_block(block, block_name, context)
636650

637-
last_block = block_name != self._last_block_name
638-
if last_block or (block_name == output_) or (block_name in output_blocks):
651+
if (block_name != self._last_block_name) or (block_name in output_blocks):
639652
self._produce_block(block, block_name, context, output_variables, outputs)
640653

641654
# We already captured the output from this block
642655
if block_name in output_blocks:
643656
output_blocks.remove(block_name)
644-
elif block_name == output_:
645-
return context
646657

647658
# If there was an output_ but there are no pending
648659
# outputs we are done.
@@ -695,9 +706,7 @@ def predict(self, X=None, output_='default', start_=None, **kwargs):
695706
if X is not None:
696707
context['X'] = X
697708

698-
output_variables = self.get_output_variables(output_)
699-
outputs = output_variables.copy()
700-
output_blocks = {variable.rsplit('.', 1)[0] for variable in output_variables}
709+
output_variables, outputs, output_blocks = self._prepare_outputs(output_)
701710

702711
if isinstance(start_, int):
703712
start_ = self._get_block_name(start_)

tests/features/test_partial_outputs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,14 @@ def test_fit_output(self):
7070
y = np.array([
7171
0, 0, 0, 0, 1
7272
])
73+
context = {'X': X, 'y': y}
7374

7475
almost_equal(named_out, y)
7576
assert len(list_out) == 2
7677
almost_equal(list_out[0], y)
77-
almost_equal(list_out[1], X)
78-
almost_equal(X, int_out)
79-
almost_equal(X, str_out)
78+
almost_equal(list_out[1], context)
79+
almost_equal(context, int_out)
80+
almost_equal(context, str_out)
8081
almost_equal(X, str_out_variable)
8182
assert no_output is None
8283

0 commit comments

Comments
 (0)