Commit a211197
feat(autogram): Add Transformer support (#447)

* Refactor how used params are deduced from the module: we now combine the direct params and the indirectly used params, and we reuse the same code between both usages of this function
* Add special case for indirectly used params of MultiheadAttention
* Add WithMultiheadAttention, WithTransformer, and WithTransformerLarge tests
* Add WithTransformerLarge speed test
* Rename trainable to rg
* Add test_batched_non_batched_equivalence_2
* Update warnings about Transformers in the docstring of Engine
1 parent 425691b commit a211197

File tree

7 files changed: +192 -25 lines changed

src/torchjd/autogram/_engine.py

Lines changed: 19 additions & 10 deletions

@@ -113,27 +113,36 @@ class Engine:
     memory-efficient, and thus typically faster, to use the Gramian-based approach.

     .. warning::
-        When providing a non-None ``batch_dim``, all provided modules must respect a few
-        conditions:
+        When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:

        * They should treat the elements of the batch independently. Most common layers respect
          this, but for example `BatchNorm
          <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
          computes some average and standard deviation over the elements of the batch).
        * Their inputs and outputs can be anything, but each input tensor and each output tensor
-          must be batched on its first dimension. `Transformers
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
-          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
-          supported yet. This is only an implementation issue, so it should be fixed soon (please
-          open an issue if you need extra focus on this).
+          must be batched on its first dimension. When available (e.g. in `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
+          `MultiheadAttention
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
+          etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this means `RNNs
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are not supported yet
+          because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``.
        * They should not perform in-place operations on tensors (for instance you should not use
          ``track_running_stats=True`` in normalization layers).
        * They should not have side effects during the forward pass (since their forward pass will
          be called twice, the side effects could be different from what's expected).
        * If they have some randomness during the forward pass, they should not have direct
-          trainable parameters. It is, however, perfectly fine for random modules to have child
-          modules that have trainable parameters, so if you have a random module with some direct
-          parameters, a simple fix is to wrap these parameters into a child module.
+          trainable parameters. For this reason,
+          `Transformers
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
+          dropout function (rather than a `Dropout
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
+          module with some trainable parameters, have to be used with
+          ``dropout=0.0``. Note that `Dropout
+          <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers are
+          entirely supported and should be preferred. It is also perfectly fine for random modules
+          to have child modules that have trainable parameters, so if you have a random module with
+          some direct parameters, a simple fix is to wrap these parameters into a child module.

        If you're building your own architecture, respecting those criteria should be quite easy.
        However, if you're using an existing architecture, you may have to modify it to make it
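To make the warning above concrete, here is a minimal usage sketch (not part of the diff; the data shapes and the loss function are illustrative, and the Engine calls mirror those used in the tests below) of a Transformer configured so that it satisfies these conditions:

import torch
from torch import nn
from torchjd.autogram import Engine

# batch_first=True so that every input/output tensor is batched on dim 0, and
# dropout=0.0 because nn.Transformer applies dropout functionally inside modules
# that hold trainable parameters.
transformer = nn.Transformer(
    d_model=8,
    nhead=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    batch_first=True,
    dropout=0.0,
)

engine = Engine(transformer.modules(), batch_dim=0)

src = torch.randn(16, 10, 8)  # (batch, source length, d_model)
tgt = torch.randn(16, 20, 8)  # (batch, target length, d_model)

# One loss per element of the batch (any per-sample losses would do here).
losses = (transformer(src, tgt) ** 2).mean(dim=(1, 2))
gramian = engine.compute_gramian(losses)  # Gramian associated with the vector of losses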

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 3 additions & 2 deletions

@@ -9,6 +9,7 @@

 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
+from ._module_utils import get_used_params
 from ._vjp import VJP, AutogradVJP, FunctionalVJP

 # Note about import from protected _pytree module:
@@ -125,8 +126,8 @@ def __call__(
             # require grad
             return outputs

-        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
-        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
+        rg_params, _ = get_used_params(module)
+        self.gramian_accumulator.track_parameter_paths(rg_params.values())

         # We only care about running the JacobianAccumulator node, so we need one of its child
         # edges (the edges of the original outputs of the model) as target. For memory
src/torchjd/autogram/_module_utils.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+from torch import nn
+
+
+def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
+    """
+    Gets all parameters that a module uses. In reality, we return all direct params (which may
+    include some unused params) and all the indirectly used params that we know about (we may be
+    missing some in weird modules).
+
+    Returns the tuple containing the params that require grad and the params that don't require
+    grad.
+    """
+
+    direct_rg_params, direct_frozen_params = _get_direct_params(module)
+    indirect_rg_params, indirect_frozen_params = _get_indirectly_used_params(module)
+    rg_params = direct_rg_params | indirect_rg_params
+    frozen_params = direct_frozen_params | indirect_frozen_params
+
+    return rg_params, frozen_params
+
+
+def _get_direct_params(
+    module: nn.Module, prefix: str = ""
+) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
+    rg_params = dict[str, nn.Parameter]()
+    frozen_params = dict[str, nn.Parameter]()
+
+    for name, param in module.named_parameters(recurse=False):
+        if param.requires_grad:
+            rg_params[prefix + name] = param
+        else:
+            frozen_params[prefix + name] = param
+
+    return rg_params, frozen_params
+
+
+def _get_indirectly_used_params(
+    module: nn.Module,
+) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
+    # MHA uses its out_proj child params itself. Note that we also check that the MHA still has
+    # an out_proj attribute because it might change in the future (which will remove the
+    # necessity of custom code for MHA entirely). See the status of
+    # https://github.com/pytorch/pytorch/pull/126568
+    if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"):
+        return _get_direct_params(module.out_proj, prefix="out_proj.")
+
+    return {}, {}
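A quick illustration of what this helper returns for a MultiheadAttention module (a sketch, not part of the diff; it assumes the import path used in _vjp.py below, and that the installed PyTorch version still exposes the out_proj attribute):

from torch import nn

from torchjd.autogram._module_utils import get_used_params

# For MultiheadAttention, the indirectly used out_proj parameters are reported
# alongside the direct parameters.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
rg_params, frozen_params = get_used_params(mha)
print(sorted(rg_params))
# Expected to include e.g. 'in_proj_bias', 'in_proj_weight', 'out_proj.bias',
# 'out_proj.weight' (all require grad by default).
print(frozen_params)  # {} unless some parameters were frozen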

src/torchjd/autogram/_vjp.py

Lines changed: 8 additions & 13 deletions

@@ -6,6 +6,8 @@
 from torch.nn import Parameter
 from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten

+from torchjd.autogram._module_utils import get_used_params
+
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
 # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
@@ -37,14 +39,7 @@ class ModuleVJP(VJP, ABC):

     def __init__(self, module: nn.Module):
         self.module = module
-        self.trainable_params = dict[str, Parameter]()
-        self.frozen_params = dict[str, Parameter]()
-
-        for name, param in module.named_parameters(recurse=False):
-            if param.requires_grad:
-                self.trainable_params[name] = param
-            else:
-                self.frozen_params[name] = param
+        self.rg_params, self.frozen_params = get_used_params(module)


 class FunctionalVJP(ModuleVJP):
@@ -78,9 +73,9 @@ def _call_on_one_instance(
         kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
         grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

-        def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
+        def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]:
             all_state = {
-                **trainable_params,
+                **rg_params,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
@@ -89,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor
             rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
             return rg_outputs

-        vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
+        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

         # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
         # functional has a single primal which is dict(module.named_parameters()). We therefore take
@@ -109,14 +104,14 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
         super().__init__(module)

         self.rg_outputs = rg_outputs
-        self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)
+        self.flat_rg_params, self.param_spec = tree_flatten(self.rg_params)

     def __call__(
         self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree]
     ) -> dict[str, Tensor]:
         grads = torch.autograd.grad(
             self.rg_outputs,
-            self.flat_trainable_params,
+            self.flat_rg_params,
             grad_outputs,
             retain_graph=True,
             allow_unused=True,
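For readers unfamiliar with the pattern FunctionalVJP relies on, here is a small standalone sketch (illustrative, not torchjd code; the module and input are arbitrary) of calling a module functionally and taking a dict of parameters as the single primal of torch.func.vjp, as in the hunk above:

import torch
from torch import nn

# Illustrative module and input; buffers are omitted for brevity.
module = nn.Linear(3, 2)
x = torch.randn(5, 3)

rg_params = {n: p for n, p in module.named_parameters() if p.requires_grad}

def functional_model_call(params: dict[str, torch.Tensor]) -> torch.Tensor:
    # Re-run the module with the given parameter values.
    return torch.func.functional_call(module, params, (x,))

output, vjp_func = torch.func.vjp(functional_model_call, rg_params)

# The VJP maps a cotangent of the output to gradients w.r.t. the single primal,
# i.e. a dict with the same keys as rg_params.
(grads,) = vjp_func(torch.ones_like(output))
print({name: g.shape for name, g in grads.items()})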

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
     NoFreeParam,
     ShapedModule,
     SqueezeNet,
+    WithTransformerLarge,
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
@@ -27,6 +28,7 @@
 from torchjd.autogram import Engine

 PARAMETRIZATIONS = [
+    (WithTransformerLarge, 8),
     (FreeParam, 64),
     (NoFreeParam, 64),
     (Cifar10Model, 64),

tests/unit/autogram/test_engine.py

Lines changed: 49 additions & 0 deletions

@@ -57,10 +57,13 @@
     WithModuleWithStringArg,
     WithModuleWithStringKwarg,
     WithModuleWithStringOutput,
+    WithMultiHeadAttention,
     WithNoTensorOutput,
     WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
+    WithTransformer,
+    WithTransformerLarge,
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
@@ -118,6 +121,8 @@
     (WithModuleWithStringOutput, 32),
     (WithModuleWithStringKwarg, 32),
     (WithModuleWithHybridPyTreeKwarg, 32),
+    (WithMultiHeadAttention, 32),
+    param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")),
     (FreeParam, 32),
     (NoFreeParam, 32),
     param(Cifar10Model, 16, marks=mark.slow),
@@ -126,6 +131,11 @@
     param(GroupNormMobileNetV3Small, 3, marks=mark.slow),
     param(SqueezeNet, 8, marks=mark.slow),
     param(InstanceNormMobileNetV2, 2, marks=mark.slow),
+    param(
+        WithTransformerLarge,
+        8,
+        marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")],
+    ),
 ]


@@ -565,3 +575,42 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
     gramian2 = engine2.compute_gramian(output)

     assert_close(gramian1, gramian2)
+
+
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int):
+    """
+    Same as test_batched_non_batched_equivalence but on real architectures, and thus only between
+    batch_dim=0 and batch_dim=None.
+
+    If for some architecture this test passes but test_compute_gramian doesn't pass, it could be
+    that get_used_params does not work for some module of the architecture.
+    """
+
+    input_shapes = architecture.INPUT_SHAPES
+    output_shapes = architecture.OUTPUT_SHAPES
+
+    torch.manual_seed(0)
+    model_0 = architecture().to(device=DEVICE)
+    torch.manual_seed(0)
+    model_none = architecture().to(device=DEVICE)
+
+    engine_0 = Engine(model_0.modules(), batch_dim=0)
+    engine_none = Engine(model_none.modules(), batch_dim=None)
+
+    inputs = make_tensors(batch_size, input_shapes)
+    targets = make_tensors(batch_size, output_shapes)
+    loss_fn = make_mse_loss_fn(targets)
+
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_0(inputs)
+    losses_0 = reduce_to_vector(loss_fn(output))
+
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_none(inputs)
+    losses_none = reduce_to_vector(loss_fn(output))
+
+    gramian_0 = engine_0.compute_gramian(losses_0)
+    gramian_none = engine_none.compute_gramian(losses_none)
+
+    assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)

tests/utils/architectures.py

Lines changed: 64 additions & 0 deletions

@@ -931,6 +931,70 @@ def forward(self, input: Tensor) -> Tensor:
         return output


+class WithMultiHeadAttention(ShapedModule):
+    """Module containing a MultiheadAttention layer."""
+
+    INPUT_SHAPES = ((20, 8), (10, 9), (10, 11))
+    OUTPUT_SHAPES = (20, 8)
+
+    def __init__(self):
+        super().__init__()
+        self.mha = nn.MultiheadAttention(
+            embed_dim=8,
+            num_heads=2,
+            dropout=0.0,
+            batch_first=True,
+            kdim=9,
+            vdim=11,
+        )
+
+    def forward(self, input: tuple[Tensor, Tensor, Tensor]) -> Tensor:
+        query, key, value = input
+        attn_output, _ = self.mha(query, key, value)
+        return attn_output
+
+
+class WithTransformer(ShapedModule):
+    """Module containing a Transformer."""
+
+    INPUT_SHAPES = ((10, 8), (20, 8))
+    OUTPUT_SHAPES = (20, 8)
+
+    def __init__(self):
+        super().__init__()
+        self.transformer = nn.Transformer(
+            d_model=8,
+            nhead=2,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=32,
+            batch_first=True,
+            dropout=0.0,
+        )
+
+    def forward(self, input: tuple[Tensor, Tensor]) -> Tensor:
+        src, tgt = input
+        return self.transformer(src, tgt)
+
+
+class WithTransformerLarge(ShapedModule):
+    """Module containing a large Transformer."""
+
+    INPUT_SHAPES = ((10, 512), (20, 512))
+    OUTPUT_SHAPES = (20, 512)
+
+    def __init__(self):
+        super().__init__()
+        self.transformer = nn.Transformer(
+            batch_first=True,
+            dropout=0.0,
+        )
+
+    def forward(self, input: tuple[Tensor, Tensor]) -> Tensor:
+        src, tgt = input
+        return self.transformer(src, tgt)
+
+
 class FreeParam(ShapedModule):
     """
     Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the
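To make the shape convention of these fixtures concrete, here is a standalone sketch mirroring WithMultiHeadAttention (not part of the diff; it assumes, as the make_tensors usage in the tests suggests, that INPUT_SHAPES and OUTPUT_SHAPES exclude the batch dimension, which sits on dim 0):

import torch
from torch import nn

# Same configuration as WithMultiHeadAttention above (embed_dim=8, kdim=9, vdim=11).
mha = nn.MultiheadAttention(
    embed_dim=8, num_heads=2, dropout=0.0, batch_first=True, kdim=9, vdim=11
)
query = torch.randn(4, 20, 8)   # batch of 4, INPUT_SHAPES[0] == (20, 8)
key = torch.randn(4, 10, 9)     # batch of 4, INPUT_SHAPES[1] == (10, 9)
value = torch.randn(4, 10, 11)  # batch of 4, INPUT_SHAPES[2] == (10, 11)
attn_output, _ = mha(query, key, value)
print(attn_output.shape)        # torch.Size([4, 20, 8]), matching OUTPUT_SHAPES == (20, 8)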
