Skip to content

Commit 98bc889

Browse files
Haonanemailweixu
authored andcommitted
split the input list of conv_operator into two inputs: image and filter (#104)
1 parent b130ba7 commit 98bc889

File tree

1 file changed

+9
-7
lines changed
  • python/paddle/trainer_config_helpers

1 file changed

+9
-7
lines changed

python/paddle/trainer_config_helpers/layers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2667,7 +2667,7 @@ def __add_evaluator__(e):
26672667

26682668
return LayerOutput(name, LayerType.COST, parents=[input, label])
26692669

2670-
def conv_operator(input, filter_size, num_filters,
2670+
def conv_operator(img, filter, filter_size, num_filters,
26712671
num_channel=None, stride=1, padding=0, groups=1,
26722672
filter_size_y=None, stride_y=None, padding_y=None):
26732673
"""
@@ -2680,13 +2680,16 @@ def conv_operator(input, filter_size, num_filters,
26802680
26812681
.. code-block:: python
26822682
2683-
op = conv_operator(input=[layer1, layer2],
2683+
op = conv_operator(img=input1,
2684+
filter=input2,
26842685
filter_size=3.0,
26852686
num_filters=64,
26862687
num_channels=64)
26872688
2688-
:param input: Input layer.
2689-
:type input: LayerOutput|list|tuple
2689+
:param img: input image
2690+
:type img: LayerOutput
2691+
:param filter: input filter
2692+
:type filter: LayerOutput
26902693
:param filter_size: The x dimension of a filter kernel.
26912694
:type filter_size: int
26922695
:param filter_size_y: The y dimension of a filter kernel. Since
@@ -2708,14 +2711,13 @@ def conv_operator(input, filter_size, num_filters,
27082711
:return: A ConvOperator Object.
27092712
:rtype: ConvOperator
27102713
"""
2711-
assert isinstance(input, list) or isinstance(input, tuple)
27122714
if filter_size_y is None:
27132715
filter_size_y = filter_size
27142716
if stride_y is None:
27152717
stride_y = stride
27162718
if padding_y is None:
27172719
padding_y = padding
2718-
op = ConvOperator(input_layer_names=[x.name for x in input],
2720+
op = ConvOperator(input_layer_names=[img.name, filter.name],
27192721
num_filters = num_filter,
27202722
conv_conf=Conv(filter_size=filter_size,
27212723
padding=padding,
@@ -2725,7 +2727,7 @@ def conv_operator(input, filter_size, num_filters,
27252727
padding_y=padding_y,
27262728
stride_y=stride_y,
27272729
groups=groups))
2728-
op.origin = input
2730+
op.origin = [img, filter]
27292731
op.origin.operator = "conv_op"
27302732
return op
27312733

0 commit comments

Comments
 (0)