     nn.LazyBatchNorm3d,
     nn.SyncBatchNorm,
     nn.RNNBase,
-    nn.Transformer,
-    nn.TransformerEncoder,
-    nn.TransformerDecoder,
-    nn.TransformerEncoderLayer,
-    nn.TransformerDecoderLayer,
 )
 
 _TRACK_RUNNING_STATS_MODULE_TYPES = (
@@ -113,28 +108,36 @@ class Engine:
     memory-efficient, and thus typically faster, to use the Gramian-based approach.
 
     .. warning::
-        When providing a non-None ``batch_dim``, all provided modules must respect a few
-        conditions:
+        When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:
 
         * They should treat the elements of the batch independently. Most common layers respect
           this, but for example `BatchNorm
           <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
           computes some average and standard deviation over the elements of the batch).
-        * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of
-          tensors, or any nesting of those structures), but each of these tensors must be batched on
-          its first dimension. `Transformers
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
-          supported yet. This is only an implementation issue, so it should be fixed soon (please
-          open an issue if you need extra focus on this).
+        * Their inputs and outputs can be anything, but each input tensor and each output tensor
+          must be batched on its first dimension. When available (e.g. in `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
+          `MultiheadAttention
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
+          etc.), the ``batch_first`` parameter has to be set to ``True`` (see the first sketch
+          after this list). This also means that `RNNs
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are not supported
+          yet, because their hidden state is batched on dimension 1 even when ``batch_first`` is
+          ``True``.
         * They should not perform in-place operations on tensors (for instance you should not use
           ``track_running_stats=True`` in normalization layers).
         * They should not have side effects during the forward pass (since their forward pass will
           be called twice, the side effects could be different from what's expected).
         * If they have some randomness during the forward pass, they should not have direct
-          trainable parameters. It is, however, perfectly fine for random modules to have child
-          modules that have trainable parameters, so if you have a random module with some direct
-          parameters, a simple fix is to wrap these parameters into a child module.
+          trainable parameters. For this reason, `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
+          dropout function (rather than a `Dropout
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
+          module with some trainable parameters, have to be used with ``dropout=0.0``. Note that
+          `Dropout <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers
+          are entirely supported and should be preferred. It is also perfectly fine for random
+          modules to have child modules that have trainable parameters, so if you have a random
+          module with some direct parameters, a simple fix is to wrap these parameters into a
+          child module (see the second sketch after this list).
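+
+        For instance, a minimal sketch of a `Transformer
+        <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ configured to
+        respect the ``batch_first`` and ``dropout`` conditions above (the other hyper-parameter
+        values are purely illustrative) could be:
+
+        >>> from torch import nn
+        >>>
+        >>> transformer = nn.Transformer(
+        >>>     d_model=64,  # Illustrative model dimension
+        >>>     nhead=4,  # Illustrative number of attention heads
+        >>>     batch_first=True,  # Inputs and outputs batched on their first dimension
+        >>>     dropout=0.0,  # Avoid dropout inside modules with direct trainable parameters
+        >>> )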
 
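+        Similarly, a random module that would otherwise hold a direct trainable parameter can be
+        fixed by moving that parameter into a child module. A purely illustrative sketch of this
+        fix (the ``Scale`` and ``RandomScale`` modules are made up for this example):
+
+        >>> import torch
+        >>> from torch import Tensor, nn
+        >>>
+        >>> class Scale(nn.Module):  # Child module holding the trainable parameter
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.weight = nn.Parameter(torch.ones(3))
+        >>>
+        >>>     def forward(self, input: Tensor) -> Tensor:
+        >>>         return input * self.weight
+        >>>
+        >>> class RandomScale(nn.Module):  # Random module with no direct parameters
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.scale = Scale()
+        >>>
+        >>>     def forward(self, input: Tensor) -> Tensor:
+        >>>         return torch.rand_like(input) * self.scale(input)
+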
         If you're building your own architecture, respecting those criteria should be quite easy.
         However, if you're using an existing architecture, you may have to modify it to make it
@@ -147,6 +150,20 @@ class Engine:
         The alternative is to use ``batch_dim=None``, but it's not recommended since it will
         increase memory usage by a lot and thus typically slow down computation.
 
+    .. warning::
+        Parent modules should call their child modules directly rather than using their child
+        modules' parameters themselves. For instance, the following model is not supported:
+
+        >>> class Model(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.linear = nn.Linear(2, 3)  # Child module
+        >>>
+        >>>     def forward(self, input: Tensor) -> Tensor:
+        >>>         # Incorrect: uses the child module's parameters directly without calling it.
+        >>>         return input @ self.linear.weight.T + self.linear.bias
+        >>>         # Correct alternative: return self.linear(input)
+
     .. note::
         For maximum efficiency, modules should ideally not contain both direct trainable
         parameters and child modules, especially if those direct trainable parameters are used