
Commit 566a940

Implement a bilinear initializer for transposed convolution to do upsampling. (#11404)
* Implement a bilinear initializer for transposed convolution.
* Update some error message.
1 parent cc1239f commit 566a940
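
The hyperparameter choices recommended in the new initializer's docstring (filter_size = 2 * factor - factor % 2, padding = ceil((factor - 1) / 2.), stride = factor) are what make the transposed convolution an exact integer-factor upsampler. A minimal sanity check of that arithmetic, using the standard transposed-convolution output-size formula (an assumption for illustration; not part of this commit):

```python
from math import ceil

factor = 2                              # illustrative upsampling factor
filter_size = 2 * factor - factor % 2   # = 4
padding = int(ceil((factor - 1) / 2.))  # = 1
stride = factor                         # = 2

H = 7                                   # example input height
# Transposed-conv output size with dilation 1 and no output padding:
H_out = (H - 1) * stride - 2 * padding + filter_size
assert H_out == factor * H              # 14 == 2 * 7
```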

File tree

2 files changed: 117 additions & 2 deletions


python/paddle/fluid/initializer.py

Lines changed: 100 additions & 2 deletions
@@ -15,11 +15,13 @@
 import framework
 import numpy as np
 import contextlib
+from framework import convert_np_dtype_to_dtype_
+from core import VarDesc

 __all__ = [
-    'Constant', 'Uniform', 'Normal', 'Xavier', 'force_init_on_cpu',
+    'Constant', 'Uniform', 'Normal', 'Xavier', 'Bilinear', 'force_init_on_cpu',
     'init_on_cpu', 'ConstantInitializer', 'UniformInitializer',
-    'NormalInitializer', 'XavierInitializer'
+    'NormalInitializer', 'XavierInitializer', 'BilinearInitializer'
 ]

 _force_init_on_cpu_ = False
@@ -422,6 +424,101 @@ def __call__(self, var, block):
         return op


+class BilinearInitializer(Initializer):
+    """Implements the bilinear initializer.
+
+    This initializer can be used in transposed convolution operator to
+    act as upsampling. Users can upsample a feature map with shape of
+    (B, C, H, W) by any integer factor. The usage is:
+
+    >>> factor = 2
+    >>> w_attr = ParamAttr(learning_rate=0., regularizer=L2Decay(0.),
+    >>>                    initializer=Bilinear())
+    >>> conv_up = fluid.layers.conv2d_transpose(
+    >>>     input,
+    >>>     num_filters=C,
+    >>>     output_size=None,
+    >>>     filter_size=2 * factor - factor % 2,
+    >>>     padding=ceil((factor - 1) / 2.),
+    >>>     stride=factor,
+    >>>     groups=C,
+    >>>     param_attr=w_attr,
+    >>>     bias_attr=False)
+
+    Here, `num_filters=C` and `groups=C` mean this is a channel-wise transposed
+    convolution. The filter shape will be (C, 1, K, K) where K is `filter_size`.
+    This initializer sets the same (K, K) interpolation kernel for every channel
+    of the filter. The resulting shape of the output feature map will be
+    (B, C, factor * H, factor * W). Note that the learning rate and the weight
+    decay are set to 0 in order to keep the coefficients of the bilinear
+    interpolation unchanged during training.
+    """
+
+    def __init__(self):
+        """Constructor for BilinearInitializer.
+        """
+        super(BilinearInitializer, self).__init__()
+
+    def __call__(self, var, block):
+        """Add bilinear initialization ops for a variable
+
+        Args:
+            var (Variable): Variable that needs to be initialized.
+            block (Block): The block in which initialization ops should
+                           be added.
+
+        Returns:
+            the initialization op
+
+        Raises:
+            ValueError: If type of `var` or `block` is not right,
+                        if the rank of `var` is not 4, or if
+                        var.shape[2] != var.shape[3].
+        """
+        if not isinstance(var, framework.Variable):
+            raise ValueError("var must be framework.Variable.")
+
+        if not isinstance(block, framework.Block):
+            raise ValueError("block must be framework.Block.")
+
+        shape = var.shape
+        if len(shape) != 4:
+            raise ValueError("the length of shape must be 4.")
+        if shape[2] != shape[3]:
+            raise ValueError("shape[2] must be equal to shape[3].")
+
+        weight = np.zeros(np.prod(var.shape), dtype='float32')
+        size = shape[3]
+        # upsampling factor
+        f = np.ceil(size / 2.)
+        # center of the bilinear kernel
+        c = (2 * f - 1 - f % 2) / (2. * f)
+        for i in range(np.prod(shape)):
+            x = i % size
+            y = (i / size) % size
+            weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+        weight = np.reshape(weight, shape)
+
+        if var.dtype == VarDesc.VarType.FP32:
+            value_name = "fp32_values"
+            values = [float(v) for v in weight.flat]
+        else:
+            raise ValueError("Unsupported dtype %s." % var.dtype)
+        if np.prod(shape) > 1024 * 1024:
+            raise ValueError("The size of input is too big.")
+        op = block.append_op(
+            type='assign_value',
+            outputs={'Out': [var]},
+            attrs={
+                'dtype': var.dtype,
+                'shape': list(shape),
+                value_name: values
+            })
+        var.op = op
+        return op
+
+
 # We shorten the class name, since users will use the initializer with the package
 # name. The sample code:
 #
@@ -436,3 +533,4 @@ def __call__(self, var, block):
 Normal = NormalInitializer
 Xavier = XavierInitializer
 MSRA = MSRAInitializer
+Bilinear = BilinearInitializer
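
For reference, a standalone NumPy sketch of the kernel construction performed in `BilinearInitializer.__call__` above; the (1, 1, 4, 4) filter shape is an illustrative choice corresponding to factor = 2 and is not taken from the commit:

```python
import numpy as np

def bilinear_kernel(shape):
    """Same weight computation as BilinearInitializer.__call__ for a
    filter of shape (C, 1, K, K) with K == shape[2] == shape[3]."""
    weight = np.zeros(np.prod(shape), dtype='float32')
    size = shape[3]
    f = np.ceil(size / 2.)              # upsampling factor
    c = (2 * f - 1 - f % 2) / (2. * f)  # kernel center
    for i in range(int(np.prod(shape))):
        x = i % size
        y = (i // size) % size          # integer division (Python 2 and 3)
        weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
    return np.reshape(weight, shape)

# factor = 2  ->  K = 2 * factor - factor % 2 = 4
print(bilinear_kernel((1, 1, 4, 4))[0, 0])
# [[0.0625 0.1875 0.1875 0.0625]
#  [0.1875 0.5625 0.5625 0.1875]
#  [0.1875 0.5625 0.5625 0.1875]
#  [0.0625 0.1875 0.1875 0.0625]]
```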

python/paddle/fluid/tests/unittests/test_initializer.py

Lines changed: 17 additions & 0 deletions
@@ -364,5 +364,22 @@ def test_msra_initializer_supplied_arguments(self):
         self.assertEqual(init_op.attr('seed'), 134)


+class TestBilinearInitializer(unittest.TestCase):
+    def test_bilinear_initializer(self):
+        """Test the bilinear initializer with supplied arguments
+        """
+        program = framework.Program()
+        block = program.global_block()
+        block.create_parameter(
+            dtype="float32",
+            shape=[8, 1, 3, 3],
+            lod_level=0,
+            name="param",
+            initializer=initializer.BilinearInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'assign_value')
+
+
 if __name__ == '__main__':
     unittest.main()
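
The test only asserts that a single assign_value op is created. A possible stricter check (a sketch; it assumes `init_op.attr('fp32_values')` returns the flattened weights, as the attrs set in `__call__` suggest) would compare them against the expected 3x3 kernel:

```python
import numpy as np

def expected_bilinear_kernel(size):
    # Same formula as BilinearInitializer.__call__, vectorized for one (size, size) kernel.
    f = np.ceil(size / 2.)
    c = (2 * f - 1 - f % 2) / (2. * f)
    hat = 1 - np.abs(np.arange(size) / f - c)  # 1-D hat function
    return np.outer(hat, hat).astype('float32')

# For the [8, 1, 3, 3] parameter above, every channel gets the same 3x3 kernel:
#   values = np.array(init_op.attr('fp32_values'), dtype='float32').reshape(8, 1, 3, 3)
#   np.testing.assert_allclose(values[0, 0], expected_bilinear_kernel(3))
```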
