
Commit 3a442b2

Merge branch 'main' into autogram-readme
2 parents: 6e98e60 + 6d0b3a8

File tree: 12 files changed (+687, -261 lines)


.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v5.0.0
+  rev: v6.0.0
   hooks:
   - id: trailing-whitespace # Trim trailing whitespace at the end of lines.
   - id: end-of-file-fixer # Make sure files end in a newline and only a newline.
@@ -19,7 +19,7 @@ repos:
     ]

 - repo: https://github.com/pycqa/isort
-  rev: 6.0.1
+  rev: 6.1.0
   hooks:
   - id: isort # Sort imports.
     args: [
@@ -31,8 +31,8 @@ repos:
       --ensure-newline-before-comments,
     ]

-- repo: https://github.com/psf/black
-  rev: 25.1.0
+- repo: https://github.com/psf/black-pre-commit-mirror
+  rev: 25.9.0
   hooks:
   - id: black # Format code.
     args: [--line-length=100]

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -103,3 +103,6 @@ full = [
     "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
     "ecos>=2.0.14", # Does not work before 2.0.14
 ]
+
+[tool.pytest.ini_options]
+xfail_strict = true
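Setting ``xfail_strict = true`` makes pytest treat a test marked ``xfail`` that unexpectedly passes (XPASS) as a failure instead of silently letting it pass. A minimal sketch of the behaviour this option enables; the test name and reason below are made up for illustration:

import pytest


@pytest.mark.xfail(reason="illustrative known issue")
def test_known_issue():
    # With xfail_strict = true, if this assertion starts passing, pytest reports
    # the test as failed ([XPASS(strict)]) so the stale xfail marker gets noticed.
    assert 1 + 1 == 2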

src/torchjd/autogram/_engine.py

Lines changed: 34 additions & 17 deletions
@@ -19,11 +19,6 @@
     nn.LazyBatchNorm3d,
     nn.SyncBatchNorm,
     nn.RNNBase,
-    nn.Transformer,
-    nn.TransformerEncoder,
-    nn.TransformerDecoder,
-    nn.TransformerEncoderLayer,
-    nn.TransformerDecoderLayer,
 )

 _TRACK_RUNNING_STATS_MODULE_TYPES = (
@@ -113,28 +108,36 @@ class Engine:
         memory-efficient, and thus typically faster, to use the Gramian-based approach.

     .. warning::
-        When providing a non-None ``batch_dim``, all provided modules must respect a few
-        conditions:
+        When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:

         * They should treat the elements of the batch independently. Most common layers respect
           this, but for example `BatchNorm
           <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
           computes some average and standard deviation over the elements of the batch).
-        * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of
-          tensors, or any nesting of those structures), but each of these tensors must be batched
-          on its first dimension. `Transformers
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
-          supported yet. This is only an implementation issue, so it should be fixed soon (please
-          open an issue if you need extra focus on this).
+        * Their inputs and outputs can be anything, but each input tensor and each output tensor
+          must be batched on its first dimension. When available (e.g. in `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
+          `MultiheadAttention
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
+          etc.), the ``batch_first`` parameter has to be set to ``True``. This also makes `RNNs
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ unsupported for now,
+          because their hidden state is batched on dimension 1 even when ``batch_first`` is ``True``.
         * They should not perform in-place operations on tensors (for instance you should not use
           ``track_running_stats=True`` in normalization layers).
         * They should not have side effects during the forward pass (since their forward pass will
           be called twice, the side effects could be different from what's expected).
         * If they have some randomness during the forward pass, they should not have direct
-          trainable parameters. It is, however, perfectly fine for random modules to have child
-          modules that have trainable parameters, so if you have a random module with some direct
-          parameters, a simple fix is to wrap these parameters into a child module.
+          trainable parameters. For this reason, `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
+          dropout function (rather than a `Dropout
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
+          module with some trainable parameters, have to be used with ``dropout=0.0``. Note that
+          `Dropout <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers
+          are entirely supported and should be preferred. It is also perfectly fine for random
+          modules to have child modules that have trainable parameters, so if you have a random
+          module with some direct parameters, a simple fix is to wrap these parameters into a
+          child module.

         If you're building your own architecture, respecting those criteria should be quite easy.
         However, if you're using an existing architecture, you may have to modify it to make it
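To make the two Transformer-related conditions above concrete (batch-first tensors, and no dropout function inside a module that has trainable parameters), a configuration along the following lines should satisfy them. This is an illustrative sketch, not code from the repository:

import torch
from torch import nn

# Encoder layer configured so that every input/output tensor is batched on its
# first dimension and no dropout function runs inside a parameterized module.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=32,
    nhead=4,
    batch_first=True,  # (batch, sequence, features) layout
    dropout=0.0,       # prefer explicit nn.Dropout layers elsewhere if needed
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

x = torch.randn(8, 10, 32)  # (batch, sequence, features)
y = encoder(x)              # same shape as x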
@@ -147,6 +150,20 @@ class Engine:
         The alternative is to use ``batch_dim=None``, but it's not recommended since it will
         increase memory usage by a lot and thus typically slow down computation.

+    .. warning::
+        Parent modules should call their child modules directly rather than using their child
+        modules' parameters themselves. For instance, the following model is not supported:
+
+        >>> class Model(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.linear = nn.Linear(2, 3)  # Child module
+        >>>
+        >>>     def forward(self, input: Tensor) -> Tensor:
+        >>>         # Incorrect: uses the child module's parameters directly without calling it.
+        >>>         return input @ self.linear.weight.T + self.linear.bias
+        >>>         # Correct alternative: return self.linear(input)
+
     .. note::
         For maximum efficiency, modules should ideally not contain both direct trainable
         parameters and child modules, especially if those direct trainable parameters are used
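The "wrap these parameters into a child module" fix mentioned in the docstring can look like the sketch below; the module and its names are made up for illustration and are not part of torchjd:

import torch
from torch import nn


class NoisyAffine(nn.Module):
    """Illustrative random module with no direct trainable parameters."""

    def __init__(self, dim: int):
        super().__init__()
        # The trainable weight and bias live in a child module instead of being
        # registered directly on this module, which uses randomness in forward().
        self.affine = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Randomness happens here, but all parameters belong to self.affine.
        return self.affine(x) + 0.01 * torch.randn_like(x)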
