Commit 9e7002f

fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest (#650)
* fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest
* polish(pu): polish unittest of mlp
* style(pu): yapf format
* style(pu): flake8 format
* polish(pu): polish the output_activation and output_norm in MLP
* style(pu): polish the annotations in MLP, yapf format
* style(pu): flake8 style fix
* fix(pu): fix output_activation and output_norm in MLP
1 parent aefddac commit 9e7002f
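At a glance, the commit does two things: the output layer of `MLP` is now controlled by two booleans (`output_activation`, `output_norm`) that reuse the front-layer settings, and `last_linear_layer_init_zero` now searches for the last `nn.Linear` instead of assuming it sits at `block[-2]`. A minimal sketch of why that positional assumption can break (the layer sizes below are made up):

```python
import torch.nn as nn

# Hypothetical tail of an MLP whose output layer keeps its norm and activation.
block = [nn.Linear(64, 8), nn.BatchNorm1d(8), nn.ReLU()]

# The old zero-init wrote into block[-2], which here is the BatchNorm, not the Linear.
print(type(block[-2]).__name__)  # BatchNorm1d

# The fixed logic scans backwards for the last nn.Linear before zeroing it.
last_linear = next(layer for layer in reversed(block) if isinstance(layer, nn.Linear))
nn.init.zeros_(last_linear.weight)
nn.init.zeros_(last_linear.bias)
```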

File tree

2 files changed: +79 -45 lines changed

* ding/torch_utils/network/nn_module.py
* ding/torch_utils/network/tests/test_nn_module.py


ding/torch_utils/network/nn_module.py

Lines changed: 33 additions & 29 deletions
@@ -314,8 +314,8 @@ def MLP(
         norm_type: str = None,
         use_dropout: bool = False,
         dropout_probability: float = 0.5,
-        output_activation: nn.Module = None,
-        output_norm_type: str = None,
+        output_activation: bool = True,
+        output_norm: bool = True,
         last_linear_layer_init_zero: bool = False
 ):
     r"""
@@ -328,15 +328,18 @@ def MLP(
         - hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
         - out_channels (:obj:`int`): Number of channels in the output tensor.
         - layer_num (:obj:`int`): Number of layers.
-        - layer_fn (:obj:`Callable`): layer function.
-        - activation (:obj:`nn.Module`): the optional activation function.
-        - norm_type (:obj:`str`): type of the normalization.
-        - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block.
-        - dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
-        - output_activation (:obj:`nn.Module`): the optional activation function in the last layer.
-        - output_norm_type (:obj:`str`): type of the normalization in the last layer.
-        - last_linear_layer_init_zero (:obj:`bool`): zero initialization for the last linear layer (including w and b),
-            which can provide stable zero outputs in the beginning.
+        - layer_fn (:obj:`Callable`): Layer function.
+        - activation (:obj:`nn.Module`): The optional activation function.
+        - norm_type (:obj:`str`): The type of the normalization.
+        - use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
+        - dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
+        - output_activation (:obj:`bool`): Whether to use activation in the output layer. If True,
+            we use the same activation as front layers. Default: True.
+        - output_norm (:obj:`bool`): Whether to use normalization in the output layer. If True,
+            we use the same normalization as front layers. Default: True.
+        - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last linear layer
+            (including w and b), which can provide stable zero outputs in the beginning,
+            usually used in the policy network in RL settings.
     Returns:
         - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.
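The docstring above is the behavioral contract: with `output_activation=True` / `output_norm=True` the output layer simply mirrors whatever `activation` and `norm_type` the hidden layers use (so the flags are no-ops when those are `None`), while `False` leaves the final `nn.Linear` bare. A small sketch of the intended module layout, assuming `build_normalization('BN', dim=1)` maps to `nn.BatchNorm1d` as in the last-layer code further below:

```python
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

# Output layer mirrors the hidden layers: ... -> Linear -> BatchNorm1d -> Tanh at the end.
mlp = MLP(4, 8, 2, layer_num=2, activation=nn.Tanh(), norm_type='BN',
          output_activation=True, output_norm=True)
print([type(m).__name__ for m in mlp][-3:])  # ['Linear', 'BatchNorm1d', 'Tanh']

# Bare output layer, e.g. for raw logits or value estimates.
mlp = MLP(4, 8, 2, layer_num=2, activation=nn.Tanh(), norm_type='BN',
          output_activation=False, output_norm=False)
print(type(mlp[-1]).__name__)  # Linear
```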
@@ -361,30 +364,31 @@ def MLP(
         if use_dropout:
             block.append(nn.Dropout(dropout_probability))

-    # the last layer
+    # The last layer
     in_channels = channels[-2]
     out_channels = channels[-1]
-    if output_activation is None and output_norm_type is None:
-        # the last layer use the same norm and activation as front layers
-        block.append(layer_fn(in_channels, out_channels))
+    block.append(layer_fn(in_channels, out_channels))
+    """
+    In the final layer of a neural network, whether to use normalization and activation are typically determined
+    based on user specifications. These specifications depend on the problem at hand and the desired properties of
+    the model's output.
+    """
+    if output_norm is True:
+        # The last layer uses the same norm as front layers.
         if norm_type is not None:
             block.append(build_normalization(norm_type, dim=1)(out_channels))
+    if output_activation is True:
+        # The last layer uses the same activation as front layers.
         if activation is not None:
             block.append(activation)
-        if use_dropout:
-            block.append(nn.Dropout(dropout_probability))
-    else:
-        # the last layer use the specific norm and activation
-        block.append(layer_fn(in_channels, out_channels))
-        if output_norm_type is not None:
-            block.append(build_normalization(output_norm_type, dim=1)(out_channels))
-        if output_activation is not None:
-            block.append(output_activation)
-        if use_dropout:
-            block.append(nn.Dropout(dropout_probability))
-    if last_linear_layer_init_zero:
-        block[-2].weight.data.fill_(0)
-        block[-2].bias.data.fill_(0)
+
+    if last_linear_layer_init_zero:
+        # Locate the last linear layer and initialize its weights and biases to 0.
+        for _, layer in enumerate(reversed(block)):
+            if isinstance(layer, nn.Linear):
+                nn.init.zeros_(layer.weight)
+                nn.init.zeros_(layer.bias)
+                break

     return sequential_pack(block)
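The net effect of the rewritten zero-init branch: when the output layer is left bare, a freshly built network produces exactly zero outputs, which is the "stable zero outputs" property the docstring mentions for RL policy heads. A quick sanity sketch (channel sizes are arbitrary):

```python
import torch
from ding.torch_utils.network.nn_module import MLP

policy_head = MLP(16, 32, 4, layer_num=2, activation=torch.nn.ReLU(),
                  output_activation=False, output_norm=False,
                  last_linear_layer_init_zero=True)

# The zeroed final Linear maps any hidden feature to 0, so initial logits are all zero.
logits = policy_head(torch.randn(5, 16))
assert torch.all(logits == 0)
```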

ding/torch_utils/network/tests/test_nn_module.py

Lines changed: 46 additions & 16 deletions
@@ -1,6 +1,8 @@
-import torch
 import pytest
-from ding.torch_utils import build_activation, build_normalization
+import torch
+from torch.testing import assert_allclose
+
+from ding.torch_utils import build_activation
 from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
     ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
     normed_linear, normed_conv2d
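The only new test dependency is `torch.testing.assert_allclose`, used to compare the zero-initialized weight and bias elementwise. Recent PyTorch releases deprecate it in favor of `torch.testing.assert_close`, which accepts the same two-tensor form:

```python
import torch
from torch.testing import assert_allclose

w = torch.zeros(4, 8)
assert_allclose(w, torch.zeros_like(w))  # passes: elementwise comparison within default tolerances

# Non-deprecated equivalent on newer PyTorch versions:
# from torch.testing import assert_close
# assert_close(w, torch.zeros_like(w))
```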
@@ -44,20 +46,48 @@ def test_weight_init(self):
             weight_init_(weight, 'xxx')

     def test_mlp(self):
-        input = torch.rand(batch_size, in_channels).requires_grad_(True)
-        block = MLP(
-            in_channels=in_channels,
-            hidden_channels=hidden_channels,
-            out_channels=out_channels,
-            layer_num=2,
-            activation=torch.nn.ReLU(inplace=True),
-            norm_type='BN',
-            output_activation=torch.nn.Identity(),
-            output_norm_type=None,
-            last_linear_layer_init_zero=True
-        )
-        output = self.run_model(input, block)
-        assert output.shape == (batch_size, out_channels)
+        layer_num = 3
+        input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)
+
+        for output_activation in [True, False]:
+            for output_norm in [True, False]:
+                for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
+                    for norm_type in ["LN", "BN", None]:
+                        # Test case 1: MLP without last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+
+                        # Test case 2: MLP with last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm,
+                            last_linear_layer_init_zero=True
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+                        last_linear_layer = None
+                        for layer in reversed(model):
+                            if isinstance(layer, torch.nn.Linear):
+                                last_linear_layer = layer
+                                break
+                        assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
+                        assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))

     def test_conv1d_block(self):
         length = 2
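The rewritten `test_mlp` sweeps 2 × 2 × 4 × 3 = 48 flag/activation/norm combinations and builds two models per combination, so both output-layer code paths and the zero-init search are exercised. To run just this test, something like the following should work (the repository-relative path is taken from the diff header; adjust it to wherever pytest is invoked from):

```python
import pytest

# Select only the updated MLP test from the touched test module.
pytest.main(["ding/torch_utils/network/tests/test_nn_module.py", "-k", "test_mlp", "-q"])
```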
