Skip to content

Commit 098302e

Browse files
Implement dynamic inputs and outputs. (#135)
* Implement dynamic inputs and outputs. * Recover block_outputs if it's a string from the block's instance. * Update tests
1 parent ae9653b commit 098302e

File tree

3 files changed

+136
-35
lines changed

3 files changed

+136
-35
lines changed

mlblocks/mlblock.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,15 @@ def _extract_params(self, kwargs, hyperparameters):
111111
if name in kwargs:
112112
init_params[name] = kwargs.pop(name)
113113

114-
fit_args = [arg['name'] for arg in self.fit_args]
115-
produce_args = [arg['name'] for arg in self.produce_args]
114+
if not isinstance(self.fit_args, str):
115+
fit_args = [arg['name'] for arg in self.fit_args]
116+
else:
117+
fit_args = []
118+
119+
if not isinstance(self.produce_args, str):
120+
produce_args = [arg['name'] for arg in self.produce_args]
121+
else:
122+
produce_args = []
116123

117124
for name in list(kwargs.keys()):
118125
if name in fit_args:
@@ -257,6 +264,8 @@ def _get_method_kwargs(self, kwargs, method_args):
257264
A dictionary containing the argument names and values to pass
258265
to the primitive method.
259266
"""
267+
if isinstance(method_args, str):
268+
method_args = getattr(self.instance, method_args)()
260269

261270
method_kwargs = dict()
262271
for arg in method_args:

mlblocks/mlpipeline.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ def _get_block_variables(self, block_name, variables_attr, names):
177177
"""
178178
block = self.blocks[block_name]
179179
variables = deepcopy(getattr(block, variables_attr))
180+
if isinstance(variables, str):
181+
variables = getattr(block.instance, variables)()
182+
180183
variable_dict = {}
181184
for variable in variables:
182185
name = variable['name']
@@ -300,6 +303,12 @@ def get_inputs(self, fit=True):
300303

301304
return inputs
302305

306+
def get_fit_args(self):
307+
return list(self.get_inputs(fit=True).values())
308+
309+
def get_predict_args(self):
310+
return list(self.get_inputs(fit=False).values())
311+
303312
def get_outputs(self, outputs='default'):
304313
"""Get the list of output variables that correspond to the specified outputs.
305314
@@ -578,6 +587,10 @@ def _get_block_args(self, block_name, block_args, context):
578587

579588
input_names = self.input_names.get(block_name, dict())
580589

590+
if isinstance(block_args, str):
591+
block = self.blocks[block_name]
592+
block_args = getattr(block.instance, block_args)()
593+
581594
kwargs = dict()
582595
for arg in block_args:
583596
name = arg['name']
@@ -591,6 +604,9 @@ def _get_block_args(self, block_name, block_args, context):
591604
def _extract_outputs(self, block_name, outputs, block_outputs):
592605
"""Extract the outputs of the method as a dict to be set into the context."""
593606
# TODO: type validation and/or transformation should be done here
607+
if isinstance(block_outputs, str):
608+
block = self.blocks[block_name]
609+
block_outputs = getattr(block.instance, block_outputs)()
594610

595611
if not isinstance(outputs, tuple):
596612
outputs = (outputs, )

tests/test_mlpipeline.py

Lines changed: 109 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def test_get_outputs_str_named(self):
381381
]
382382
}
383383
pipeline = MLPipeline(['a_primitive', 'another_primitive'], outputs=outputs)
384+
384385
returned = pipeline.get_outputs('debug')
385386

386387
expected = [
@@ -389,13 +390,11 @@ def test_get_outputs_str_named(self):
389390
'variable': 'another_variable',
390391
}
391392
]
392-
393393
assert returned == expected
394394

395395
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
396396
def test_get_outputs_str_variable(self):
397397
pipeline = MLPipeline(['a_primitive', 'another_primitive'])
398-
399398
pipeline.blocks['a_primitive#1'].produce_output = [
400399
{
401400
'name': 'output',
@@ -412,7 +411,6 @@ def test_get_outputs_str_variable(self):
412411
'variable': 'a_primitive#1.output'
413412
}
414413
]
415-
416414
assert returned == expected
417415

418416
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
@@ -427,7 +425,6 @@ def test_get_outputs_str_block(self):
427425
'variable': 'a_primitive#1',
428426
}
429427
]
430-
431428
assert returned == expected
432429

433430
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
@@ -442,7 +439,6 @@ def test_get_outputs_int(self):
442439
'variable': 'another_primitive#1',
443440
}
444441
]
445-
446442
assert returned == expected
447443

448444
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
@@ -463,7 +459,6 @@ def test_get_outputs_combination(self):
463459
]
464460
}
465461
pipeline = MLPipeline(['a_primitive', 'another_primitive'], outputs=outputs)
466-
467462
pipeline.blocks['a_primitive#1'].produce_output = [
468463
{
469464
'name': 'output',
@@ -498,7 +493,6 @@ def test_get_outputs_combination(self):
498493
'variable': 'a_primitive#1.output'
499494
}
500495
]
501-
502496
assert returned == expected
503497

504498
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
@@ -550,32 +544,90 @@ def test_get_output_variables(self):
550544
assert names == ['a_variable']
551545

552546
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
553-
def test__get_block_variables(self):
547+
def test__get_block_variables_is_dict(self):
548+
pipeline = MLPipeline(['a_primitive'])
549+
pipeline.blocks['a_primitive#1'].produce_outputs = [
550+
{
551+
'name': 'output',
552+
'type': 'whatever'
553+
}
554+
]
555+
556+
outputs = pipeline._get_block_variables(
557+
'a_primitive#1',
558+
'produce_outputs',
559+
{'output': 'name_output'}
560+
)
561+
554562
expected = {
555563
'name_output': {
556564
'name': 'output',
557565
'type': 'whatever',
558566
}
559567
}
568+
assert outputs == expected
560569

570+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
571+
def test__get_block_variables_is_str(self):
561572
pipeline = MLPipeline(['a_primitive'])
562-
563-
pipeline.blocks['a_primitive#1'].produce_outputs = [
573+
pipeline.blocks['a_primitive#1'].produce_outputs = 'get_produce_outputs'
574+
pipeline.blocks['a_primitive#1'].instance.get_produce_outputs.return_value = [
564575
{
565-
'name': 'output',
566-
'type': 'whatever'
576+
'name': 'output_from_function',
577+
'type': 'test'
567578
}
579+
568580
]
569581

570582
outputs = pipeline._get_block_variables(
571583
'a_primitive#1',
572584
'produce_outputs',
573585
{'output': 'name_output'}
574586
)
587+
588+
expected = {
589+
'output_from_function': {
590+
'name': 'output_from_function',
591+
'type': 'test',
592+
}
593+
}
575594
assert outputs == expected
595+
pipeline.blocks['a_primitive#1'].instance.get_produce_outputs.assert_called_once_with()
576596

577597
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
578598
def test_get_inputs_fit(self):
599+
pipeline = MLPipeline(['a_primitive', 'another_primitive'])
600+
pipeline.blocks['a_primitive#1'].produce_args = [
601+
{
602+
'name': 'input',
603+
'type': 'whatever'
604+
}
605+
]
606+
pipeline.blocks['a_primitive#1'].fit_args = [
607+
{
608+
'name': 'fit_input',
609+
'type': 'whatever'
610+
}
611+
]
612+
pipeline.blocks['a_primitive#1'].produce_output = [
613+
{
614+
'name': 'output',
615+
'type': 'another_whatever'
616+
}
617+
]
618+
pipeline.blocks['another_primitive#1'].produce_args = [
619+
{
620+
'name': 'output',
621+
'type': 'another_whatever'
622+
},
623+
{
624+
'name': 'another_input',
625+
'type': 'another_whatever'
626+
}
627+
]
628+
629+
inputs = pipeline.get_inputs()
630+
579631
expected = {
580632
'input': {
581633
'name': 'input',
@@ -589,32 +641,30 @@ def test_get_inputs_fit(self):
589641
'name': 'another_input',
590642
'type': 'another_whatever',
591643
}
592-
593644
}
645+
assert inputs == expected
594646

647+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
648+
def test_get_inputs_no_fit(self):
595649
pipeline = MLPipeline(['a_primitive', 'another_primitive'])
596-
597650
pipeline.blocks['a_primitive#1'].produce_args = [
598651
{
599652
'name': 'input',
600653
'type': 'whatever'
601654
}
602655
]
603-
604656
pipeline.blocks['a_primitive#1'].fit_args = [
605657
{
606658
'name': 'fit_input',
607659
'type': 'whatever'
608660
}
609661
]
610-
611662
pipeline.blocks['a_primitive#1'].produce_output = [
612663
{
613664
'name': 'output',
614665
'type': 'another_whatever'
615666
}
616667
]
617-
618668
pipeline.blocks['another_primitive#1'].produce_args = [
619669
{
620670
'name': 'output',
@@ -626,11 +676,8 @@ def test_get_inputs_fit(self):
626676
}
627677
]
628678

629-
inputs = pipeline.get_inputs()
630-
assert inputs == expected
679+
inputs = pipeline.get_inputs(fit=False)
631680

632-
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
633-
def test_get_inputs_no_fit(self):
634681
expected = {
635682
'input': {
636683
'name': 'input',
@@ -640,46 +687,75 @@ def test_get_inputs_no_fit(self):
640687
'name': 'another_input',
641688
'type': 'another_whatever',
642689
}
643-
644690
}
691+
assert inputs == expected
645692

646-
pipeline = MLPipeline(['a_primitive', 'another_primitive'])
647-
693+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
694+
def test_get_fit_args(self):
695+
pipeline = MLPipeline(['a_primitive'])
648696
pipeline.blocks['a_primitive#1'].produce_args = [
649697
{
650698
'name': 'input',
651699
'type': 'whatever'
652700
}
653701
]
654-
655702
pipeline.blocks['a_primitive#1'].fit_args = [
656703
{
657704
'name': 'fit_input',
658705
'type': 'whatever'
659706
}
660707
]
661-
662708
pipeline.blocks['a_primitive#1'].produce_output = [
663709
{
664710
'name': 'output',
665711
'type': 'another_whatever'
666712
}
667713
]
668714

669-
pipeline.blocks['another_primitive#1'].produce_args = [
715+
outputs = pipeline.get_fit_args()
716+
717+
expected = [
670718
{
671-
'name': 'output',
672-
'type': 'another_whatever'
719+
'name': 'input',
720+
'type': 'whatever'
673721
},
674722
{
675-
'name': 'another_input',
676-
'type': 'another_whatever'
723+
'name': 'fit_input',
724+
'type': 'whatever',
677725
}
678726
]
727+
assert outputs == expected
679728

680-
inputs = pipeline.get_inputs(fit=False)
729+
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
730+
def test_get_predict_args(self):
731+
pipeline = MLPipeline(['a_primitive'])
732+
pipeline.blocks['a_primitive#1'].produce_args = [
733+
{
734+
'name': 'input',
735+
'type': 'whatever'
736+
}
737+
]
738+
pipeline.blocks['a_primitive#1'].fit_args = [
739+
{
740+
'name': 'fit_input',
741+
'type': 'whatever'
742+
}
743+
]
744+
pipeline.blocks['a_primitive#1'].produce_output = [
745+
{
746+
'name': 'output',
747+
'type': 'another_whatever'
748+
}
749+
]
750+
outputs = pipeline.get_predict_args()
681751

682-
assert inputs == expected
752+
expected = [
753+
{
754+
'name': 'input',
755+
'type': 'whatever'
756+
}
757+
]
758+
assert outputs == expected
683759

684760
@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
685761
def test_fit_pending_all_primitives(self):

0 commit comments

Comments
 (0)