Skip to content

Commit d5ba6e4

Browse files
authored
Merge pull request #102 from HDI-Project/issue-96-allow-passing-fit-and-produce-args-as-init_params
Allow passing fit and produce args as init params
2 parents ab1d483 + c78c137 commit d5ba6e4

File tree

5 files changed

+96
-34
lines changed

5 files changed

+96
-34
lines changed

mlblocks/mlblock.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313

1414
def import_object(object_name):
1515
"""Import an object from its Fully Qualified Name."""
16-
package, name = object_name.rsplit('.', 1)
17-
return getattr(importlib.import_module(package), name)
16+
if isinstance(object_name, str):
17+
package, name = object_name.rsplit('.', 1)
18+
return getattr(importlib.import_module(package), name)
19+
20+
return object_name
1821

1922

2023
class MLBlock():
@@ -27,7 +30,7 @@ class MLBlock():
2730
2831
Attributes:
2932
name (str):
30-
Name given to this MLBlock.
33+
Primitive name.
3134
metadata (dict):
3235
Additional information about this primitive
3336
primitive (object):
@@ -46,8 +49,8 @@ class MLBlock():
4649
function.
4750
4851
Args:
49-
name (str):
50-
Name given to this MLBlock.
52+
primitive (str or dict):
53+
primitive name or primitive dictionary.
5154
**kwargs:
5255
Any additional arguments that will be used as hyperparameters or passed to the
5356
``fit`` or ``produce`` methods.
@@ -143,10 +146,12 @@ def _get_tunable(cls, hyperparameters, init_params):
143146

144147
return tunable
145148

146-
def __init__(self, name, **kwargs):
147-
self.name = name
149+
def __init__(self, primitive, **kwargs):
150+
if isinstance(primitive, str):
151+
primitive = load_primitive(primitive)
148152

149-
self.metadata = load_primitive(name)
153+
self.metadata = primitive
154+
self.name = primitive['name']
150155

151156
self.primitive = import_object(self.metadata['primitive'])
152157

@@ -252,11 +257,9 @@ def _get_method_kwargs(self, kwargs, method_args):
252257

253258
if name in kwargs:
254259
value = kwargs[name]
255-
256260
elif 'default' in arg:
257261
value = arg['default']
258-
259-
else:
262+
elif arg.get('required', True):
260263
raise TypeError("missing expected argument '{}'".format(name))
261264

262265
method_kwargs[keyword] = value

mlblocks/mlpipeline.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,21 @@ def _build_blocks(self):
8787

8888
block_names_count = Counter()
8989
for primitive in self.primitives:
90+
if isinstance(primitive, str):
91+
primitive_name = primitive
92+
else:
93+
primitive_name = primitive['name']
94+
9095
try:
91-
block_names_count.update([primitive])
92-
block_count = block_names_count[primitive]
93-
block_name = '{}#{}'.format(primitive, block_count)
96+
block_names_count.update([primitive_name])
97+
block_count = block_names_count[primitive_name]
98+
block_name = '{}#{}'.format(primitive_name, block_count)
9499
block_params = self.init_params.get(block_name, dict())
95100
if not block_params:
96-
block_params = self.init_params.get(primitive, dict())
101+
block_params = self.init_params.get(primitive_name, dict())
97102
if block_params and block_count > 1:
98103
LOGGER.warning(("Non-numbered init_params are being used "
99-
"for more than one block %s."), primitive)
104+
"for more than one block %s."), primitive_name)
100105

101106
block = MLBlock(primitive, **block_params)
102107
blocks[block_name] = block
@@ -330,10 +335,6 @@ def _get_block_args(self, block_name, block_args, context):
330335

331336
if variable in context:
332337
kwargs[name] = context[variable]
333-
elif 'default' in arg:
334-
kwargs[name] = arg['default']
335-
elif arg.get('required', True):
336-
raise ValueError('Input variable {} not found in context'.format(variable))
337338

338339
return kwargs
339340

@@ -517,11 +518,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
517518
the value of that variable from the context will extracted and returned
518519
after the produce method of that block has been called.
519520
"""
520-
context = {
521-
'X': X,
522-
'y': y
523-
}
524-
context.update(kwargs)
521+
context = kwargs.copy()
522+
if X is not None:
523+
context['X'] = X
524+
525+
if y is not None:
526+
context['y'] = y
525527

526528
output_block, output_variable = self._get_output_spec(output_)
527529
last_block_name = self._get_block_name(-1)
@@ -624,10 +626,9 @@ def predict(self, X=None, output_=None, start_=None, **kwargs):
624626
the value of that variable from the context will extracted and returned
625627
after the produce method of that block has been called.
626628
"""
627-
context = {
628-
'X': X
629-
}
630-
context.update(kwargs)
629+
context = kwargs.copy()
630+
if X is not None:
631+
context['X'] = X
631632

632633
output_block, output_variable = self._get_output_spec(output_)
633634

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from mlblocks.mlpipeline import MLPipeline
2+
3+
4+
def test_fit_predict_args_in_init():
5+
6+
def add(a, b):
7+
return a + b
8+
9+
primitive = {
10+
'name': 'add',
11+
'primitive': add,
12+
'produce': {
13+
'args': [
14+
{
15+
'name': 'a',
16+
'type': 'float',
17+
},
18+
{
19+
'name': 'b',
20+
'type': 'float',
21+
},
22+
],
23+
'output': [
24+
{
25+
'type': 'float',
26+
'name': 'out'
27+
}
28+
]
29+
}
30+
}
31+
32+
primitives = [primitive]
33+
init_params = {
34+
'add': {
35+
'b': 10
36+
}
37+
}
38+
pipeline = MLPipeline(primitives, init_params=init_params)
39+
40+
out = pipeline.predict(a=3)
41+
42+
assert out == 13

tests/test_mlblock.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def test__get_tunable_condition_match_null(self):
323323
@patch('mlblocks.mlblock.load_primitive')
324324
def test___init__(self, load_primitive_mock, import_object_mock, set_hps_mock):
325325
load_primitive_mock.return_value = {
326+
'name': 'a_primitive_name',
326327
'primitive': 'a_primitive_name',
327328
'produce': {
328329
'args': [
@@ -335,9 +336,22 @@ def test___init__(self, load_primitive_mock, import_object_mock, set_hps_mock):
335336
}
336337
}
337338

338-
mlblock = MLBlock('given_primitive_name', argument='value')
339+
mlblock = MLBlock('a_primitive_name', argument='value')
339340

340-
assert mlblock.name == 'given_primitive_name'
341+
assert mlblock.metadata == {
342+
'name': 'a_primitive_name',
343+
'primitive': 'a_primitive_name',
344+
'produce': {
345+
'args': [
346+
{
347+
'name': 'argument'
348+
}
349+
],
350+
'output': [
351+
]
352+
}
353+
}
354+
assert mlblock.name == 'a_primitive_name'
341355
assert mlblock.primitive == import_object_mock.return_value
342356
assert mlblock._fit == dict()
343357
assert mlblock.fit_args == list()
@@ -370,22 +384,24 @@ def test___init__(self, load_primitive_mock, import_object_mock, set_hps_mock):
370384
@patch('mlblocks.mlblock.load_primitive')
371385
def test___str__(self, load_primitive_mock, import_object_mock):
372386
load_primitive_mock.return_value = {
387+
'name': 'a_primitive_name',
373388
'primitive': 'a_primitive_name',
374389
'produce': {
375390
'args': [],
376391
'output': []
377392
}
378393
}
379394

380-
mlblock = MLBlock('given_primitive_name')
395+
mlblock = MLBlock('a_primitive_name')
381396

382-
assert str(mlblock) == 'MLBlock - given_primitive_name'
397+
assert str(mlblock) == 'MLBlock - a_primitive_name'
383398

384399
@patch('mlblocks.mlblock.import_object')
385400
@patch('mlblocks.mlblock.load_primitive')
386401
def test_get_tunable_hyperparameters(self, load_primitive_mock, import_object_mock):
387402
"""get_tunable_hyperparameters has to return a copy of the _tunables attribute."""
388403
load_primitive_mock.return_value = {
404+
'name': 'a_primitive_name',
389405
'primitive': 'a_primitive_name',
390406
'produce': {
391407
'args': [],
@@ -433,6 +449,7 @@ def primitive(a_list_param):
433449
io_mock.return_value = primitive
434450

435451
lp_mock.return_value = {
452+
'name': 'a_primitive',
436453
'primitive': 'a_primitive',
437454
'produce': {
438455
'args': [],

tests/test_mlpipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def test__get_block_args(self):
270270

271271
expected = {
272272
'arg_1': 'arg_1_value',
273-
'arg_2': 'arg_2_value',
274273
'arg_3': 'arg_3_value',
275274
}
276275
assert args == expected

0 commit comments

Comments
 (0)