
Commit 3048b1f

Merge pull request #1027 from lcy-seso/enable_drop_in_average_and_max_layer
Enable drop in average and max layer
2 parents dbc87e3 + 4375a64 commit 3048b1f

1 file changed (+31, -62)

python/paddle/trainer/config_parser.py

Lines changed: 31 additions & 62 deletions
@@ -2301,14 +2301,9 @@ def Generator(
 
 @config_layer('expand')
 class ExpandLayer(LayerBase):
-    def __init__(self,
-                 name,
-                 inputs,
-                 trans_type='non-seq',
-                 device=None,
-                 bias=False):
+    def __init__(self, name, inputs, trans_type='non-seq', bias=False, **xargs):
         super(ExpandLayer, self).__init__(
-            name, 'expand', 0, inputs=inputs, device=device)
+            name, 'expand', 0, inputs=inputs, **xargs)
         config_assert(
             len(self.inputs) == 2, 'ExpandLayer takes 2 and only 2 inputs')
         self.config.trans_type = trans_type
@@ -2339,11 +2334,10 @@ def __init__(self,
                  inputs,
                  trans_type='non-seq',
                  active_type='linear',
-                 device=None,
                  bias=False,
-                 output_max_index=None):
-        super(MaxLayer, self).__init__(
-            name, 'max', 0, inputs=inputs, device=device)
+                 output_max_index=None,
+                 **xargs):
+        super(MaxLayer, self).__init__(name, 'max', 0, inputs=inputs, **xargs)
         config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input')
         self.config.trans_type = trans_type
         self.config.active_type = active_type
@@ -2390,15 +2384,15 @@ def __init__(self,
                  inputs,
                  active_type='linear',
                  trans_type='non-seq',
-                 device=None,
-                 bias=False):
+                 bias=False,
+                 **xargs):
         super(SequenceLastInstanceLayer, self).__init__(
             name,
             'seqlastins',
             0,
             inputs=inputs,
-            device=device,
-            active_type=active_type)
+            active_type=active_type,
+            **xargs)
         config_assert(
             len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input')
         self.config.trans_type = trans_type
@@ -2410,39 +2404,29 @@ def __init__(self,
 
 @config_layer('seqfirstins')
 class SequenceFirstInstanceLayer(SequenceLastInstanceLayer):
-    def __init__(
-            self,
-            name,
-            inputs,
-            active_type='linear',
-            trans_type='non-seq',
-            device=None,
-            bias=False, ):
+    def __init__(self,
+                 name,
+                 inputs,
+                 active_type='linear',
+                 trans_type='non-seq',
+                 bias=False,
+                 **xargs):
         super(SequenceFirstInstanceLayer, self).__init__(
-            name,
-            inputs=inputs,
-            active_type=active_type,
-            device=device,
-            bias=bias)
+            name, inputs=inputs, active_type=active_type, bias=bias, **xargs)
         self.config.trans_type = trans_type
         self.config.select_first = True
 
 
 @config_layer('seqconcat')
 class SequenceConcatLayer(LayerBase):
-    def __init__(self,
-                 name,
-                 inputs,
-                 active_type='linear',
-                 device=None,
-                 bias=False):
+    def __init__(self, name, inputs, active_type='linear', bias=False, **xargs):
         super(SequenceConcatLayer, self).__init__(
             name,
             'seqconcat',
             0,
             inputs=inputs,
-            device=device,
-            active_type=active_type)
+            active_type=active_type,
+            **xargs)
         config_assert(
             len(inputs) == 2, 'SequenceConcatLayer must have 2 inputs')
         for input_index in xrange(len(self.inputs)):
@@ -2458,15 +2442,15 @@ def __init__(self,
                  size,
                  inputs,
                  active_type='linear',
-                 device=None,
-                 bias=False):
+                 bias=False,
+                 **xargs):
         super(SequenceReshapeLayer, self).__init__(
             name,
             'seqreshape',
             size,
             inputs=inputs,
-            device=device,
-            active_type=active_type)
+            active_type=active_type,
+            **xargs)
         config_assert(
             len(inputs) == 1, 'SequenceReshapeLayer must have 1 inputs')
         self.set_layer_size(size)
@@ -2475,19 +2459,9 @@ def __init__(self,
 
 @config_layer('subseq')
 class SubSequenceLayer(LayerBase):
-    def __init__(self,
-                 name,
-                 inputs,
-                 active_type='linear',
-                 device=None,
-                 bias=False):
+    def __init__(self, name, inputs, active_type='linear', bias=False, **xargs):
         super(SubSequenceLayer, self).__init__(
-            name,
-            'subseq',
-            0,
-            inputs=inputs,
-            device=device,
-            active_type=active_type)
+            name, 'subseq', 0, inputs=inputs, active_type=active_type, **xargs)
         config_assert(len(inputs) == 3, 'SubSequenceLayer must have 3 inputs')
         input_layer0 = self.get_input_layer(0)
         size = input_layer0.size
@@ -2644,15 +2618,10 @@ def __init__(self,
                  average_strategy='average',
                  trans_type='non-seq',
                  active_type='linear',
-                 device=None,
-                 bias=False):
+                 bias=False,
+                 **xargs):
         super(AverageLayer, self).__init__(
-            name,
-            'average',
-            0,
-            inputs=inputs,
-            device=device,
-            active_type=active_type)
+            name, 'average', 0, inputs=inputs, active_type=active_type, **xargs)
         self.config.average_strategy = average_strategy
         self.config.trans_type = trans_type
         config_assert(len(inputs) == 1, 'AverageLayer must have 1 input')
@@ -2676,9 +2645,9 @@ def __init__(self, name, inputs, cos_scale=1, device=None):
 
 @config_layer('tensor')
 class TensorLayer(LayerBase):
-    def __init__(self, name, size, inputs, device=None, bias=True, **xargs):
+    def __init__(self, name, size, inputs, bias=True, **xargs):
         super(TensorLayer, self).__init__(
-            name, 'tensor', size, inputs=inputs, device=device, **xargs)
+            name, 'tensor', size, inputs=inputs, **xargs)
         config_assert(len(self.inputs) == 2, 'TensorLayer must have 2 inputs')
         config_assert(size > 0, 'size must be positive')
         config_assert(inputs[1].parameter_name == None,
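
Every hunk in this commit makes the same change: the explicit device=None parameter is removed from each layer's __init__, and the remaining keyword arguments are forwarded to LayerBase.__init__ through **xargs. Base-class options such as drop_rate (the "drop" of the commit title) can therefore reach the average and max layers instead of being rejected. Below is a minimal, self-contained sketch of this forwarding pattern; the stand-in classes are hypothetical illustrations, not the real config_parser classes.

# A simplified stand-in for LayerBase: it owns the shared keyword
# options (device, drop_rate, ...) that subclasses should pass through.
class LayerBase(object):
    def __init__(self, name, type, size, inputs, device=None, drop_rate=0.):
        self.name = name
        self.type = type
        self.size = size
        self.inputs = inputs
        self.device = device
        self.drop_rate = drop_rate

# Before this commit: the subclass named device explicitly, so any other
# base-class keyword, e.g. drop_rate, raised a TypeError.
class OldMaxLayer(LayerBase):
    def __init__(self, name, inputs, device=None):
        super(OldMaxLayer, self).__init__(
            name, 'max', 0, inputs=inputs, device=device)

# After this commit: **xargs forwards every base-class keyword unchanged.
class NewMaxLayer(LayerBase):
    def __init__(self, name, inputs, **xargs):
        super(NewMaxLayer, self).__init__(name, 'max', 0, inputs=inputs, **xargs)

NewMaxLayer('pool', inputs=['in'], drop_rate=0.5)    # drop_rate reaches LayerBase
# OldMaxLayer('pool', inputs=['in'], drop_rate=0.5)  # TypeError before the change

A side effect of the same cleanup is that device still works for all of these layers, since LayerBase already accepts it; the subclasses simply stop re-declaring it.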

0 commit comments
