|
| 1 | +.. _cn_api_paddle_nn_ModuleDict: |
| 2 | + |
| 3 | +ModuleDict |
| 4 | +------------------------------- |
| 5 | + |
| 6 | +.. py:class:: paddle.nn.ModuleDict(modules=None) |
| 7 | +
|
| 8 | +
|
| 9 | +
|
| 10 | +
|
| 11 | +ModuleDict 用于保存子层到有序字典中,它包含的子层将被正确地注册和添加。列表中的子层可以像常规 python 有序字典一样被访问。 |
| 12 | + |
| 13 | +.. note:: |
| 14 | + ``LayerDict`` 是 ``ModuleDict`` 的别名,两者在使用和功能上完全等价。 |
| 15 | + |
| 16 | +参数 |
| 17 | +:::::::::::: |
| 18 | + |
| 19 | + - **modules** (ModuleDict|OrderedDict|list[(key, Module)],可选) - 键值对的可迭代对象,值的类型为 `paddle.nn.Module` 。 |
| 20 | + |
| 21 | + |
| 22 | +代码示例 |
| 23 | +:::::::::::: |
| 24 | + |
| 25 | +.. code-block:: python |
| 26 | +
|
| 27 | + >>> import paddle |
| 28 | + >>> import numpy as np |
| 29 | + >>> from collections import OrderedDict |
| 30 | +
|
| 31 | + >>> modules = OrderedDict([ |
| 32 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 33 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 34 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 35 | + >>> ]) |
| 36 | +
|
| 37 | + >>> modules_dict = paddle.nn.ModuleDict(modules=modules) |
| 38 | +
|
| 39 | + >>> l = modules_dict['conv1d'] |
| 40 | +
|
| 41 | + >>> for k in modules_dict: |
| 42 | + ... l = modules_dict[k] |
| 43 | + ... |
| 44 | + >>> print(len(modules_dict)) |
| 45 | + 3 |
| 46 | +
|
| 47 | + >>> del modules_dict['conv2d'] |
| 48 | + >>> print(len(modules_dict)) |
| 49 | + 2 |
| 50 | +
|
| 51 | + >>> conv1d = modules_dict.pop('conv1d') |
| 52 | + >>> print(len(modules_dict)) |
| 53 | + 1 |
| 54 | +
|
| 55 | + >>> modules_dict.clear() |
| 56 | + >>> print(len(modules_dict)) |
| 57 | + 0 |
| 58 | +
|
| 59 | +方法 |
| 60 | +:::::::::::: |
| 61 | +clear() |
| 62 | +''''''''' |
| 63 | + |
| 64 | +清除 ModuleDict 中所有的子层。 |
| 65 | + |
| 66 | +**参数** |
| 67 | + |
| 68 | + 无。 |
| 69 | + |
| 70 | +**代码示例** |
| 71 | + |
| 72 | +.. code-block:: python |
| 73 | +
|
| 74 | + >>> import paddle |
| 75 | + >>> from collections import OrderedDict |
| 76 | +
|
| 77 | + >>> modules = OrderedDict([ |
| 78 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 79 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 80 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 81 | + >>> ]) |
| 82 | +
|
| 83 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 84 | + >>> len(module_dict) |
| 85 | + 3 |
| 86 | +
|
| 87 | + >>> module_dict.clear() |
| 88 | + >>> len(module_dict) |
| 89 | + 0 |
| 90 | +
|
| 91 | +pop() |
| 92 | +''''''''' |
| 93 | + |
| 94 | +移除 ModuleDict 中的键 并且返回该键对应的子层。 |
| 95 | + |
| 96 | +**参数** |
| 97 | + |
| 98 | + - **key** (str) - 要移除的 key。 |
| 99 | + |
| 100 | +**代码示例** |
| 101 | + |
| 102 | +.. code-block:: python |
| 103 | +
|
| 104 | + >>> import paddle |
| 105 | + >>> from collections import OrderedDict |
| 106 | +
|
| 107 | + >>> modules = OrderedDict([ |
| 108 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 109 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 110 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 111 | + >>> ]) |
| 112 | +
|
| 113 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 114 | + >>> len(module_dict) |
| 115 | + 3 |
| 116 | +
|
| 117 | + >>> module_dict.pop('conv2d') |
| 118 | + >>> len(module_dict) |
| 119 | + 2 |
| 120 | +
|
| 121 | +keys() |
| 122 | +''''''''' |
| 123 | + |
| 124 | +返回 ModuleDict 中键的可迭代对象。 |
| 125 | + |
| 126 | +**参数** |
| 127 | + |
| 128 | + 无。 |
| 129 | + |
| 130 | +**代码示例** |
| 131 | + |
| 132 | +.. code-block:: python |
| 133 | +
|
| 134 | + >>> import paddle |
| 135 | + >>> from collections import OrderedDict |
| 136 | +
|
| 137 | + >>> modules = OrderedDict([ |
| 138 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 139 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 140 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 141 | + >>> ]) |
| 142 | +
|
| 143 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 144 | + >>> for k in module_dict.keys(): |
| 145 | + ... print(k) |
| 146 | + conv1d |
| 147 | + conv2d |
| 148 | + conv3d |
| 149 | +
|
| 150 | +
|
| 151 | +items() |
| 152 | +''''''''' |
| 153 | + |
| 154 | +返回 ModuleDict 中键/值对的可迭代对象。 |
| 155 | + |
| 156 | +**参数** |
| 157 | + |
| 158 | + 无。 |
| 159 | + |
| 160 | +**代码示例** |
| 161 | + |
| 162 | +.. code-block:: python |
| 163 | +
|
| 164 | + >>> import paddle |
| 165 | + >>> from collections import OrderedDict |
| 166 | +
|
| 167 | + >>> modules = OrderedDict([ |
| 168 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 169 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 170 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 171 | + >>> ]) |
| 172 | +
|
| 173 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 174 | + >>> for k, v in module_dict.items(): |
| 175 | + ... print(f"{k}:", v) |
| 176 | + conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL) |
| 177 | + conv2d : Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW) |
| 178 | + conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW) |
| 179 | +
|
| 180 | +
|
| 181 | +values() |
| 182 | +''''''''' |
| 183 | + |
| 184 | +返回 ModuleDict 中值的可迭代对象。 |
| 185 | + |
| 186 | +**参数** |
| 187 | + |
| 188 | + 无。 |
| 189 | + |
| 190 | +**代码示例** |
| 191 | + |
| 192 | +.. code-block:: python |
| 193 | +
|
| 194 | + >>> import paddle |
| 195 | + >>> from collections import OrderedDict |
| 196 | +
|
| 197 | + >>> modules = OrderedDict([ |
| 198 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 199 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 200 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 201 | + >>> ]) |
| 202 | +
|
| 203 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 204 | + >>> for v in module_dict.values(): |
| 205 | + ... print(v) |
| 206 | + Conv1D(3, 2, kernel_size=[3], data_format=NCL) |
| 207 | + Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW) |
| 208 | + Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW) |
| 209 | +
|
| 210 | +
|
| 211 | +update() |
| 212 | +''''''''' |
| 213 | + |
| 214 | +更新子层中的键/值对到 ModuleDict 中,会覆盖已经存在的键。 |
| 215 | + |
| 216 | +**参数** |
| 217 | + |
| 218 | + - **sublayers** (ModuleDict|OrderedDict|list[(key, Module)]) - 键值对的可迭代对象,值的类型为 `paddle.nn.Module` 。 |
| 219 | + |
| 220 | +**代码示例** |
| 221 | + |
| 222 | +.. code-block:: python |
| 223 | +
|
| 224 | + >>> import paddle |
| 225 | + >>> from collections import OrderedDict |
| 226 | +
|
| 227 | + >>> modules = OrderedDict([ |
| 228 | + ... ('conv1d', paddle.nn.Conv1D(3, 2, 3)), |
| 229 | + ... ('conv2d', paddle.nn.Conv2D(3, 2, 3)), |
| 230 | + ... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))), |
| 231 | + >>> ]) |
| 232 | +
|
| 233 | + >>> new_submodules = OrderedDict([ |
| 234 | + ... ('relu', paddle.nn.ReLU()), |
| 235 | + ... ('conv2d', paddle.nn.Conv2D(4, 2, 4)), |
| 236 | + >>> ]) |
| 237 | + >>> module_dict = paddle.nn.ModuleDict(modules=modules) |
| 238 | +
|
| 239 | + >>> module_dict.update(new_submodules) |
| 240 | +
|
| 241 | + >>> for k, v in module_dict.items(): |
| 242 | + ... print(f"{k}:", v) |
| 243 | + conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL) |
| 244 | + conv2d : Conv2D(4, 2, kernel_size=[4, 4], data_format=NCHW) |
| 245 | + conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW) |
| 246 | + relu : ReLU() |
0 commit comments