Commit c4fbdaf
Core: Multiple Inputs and Keyword Arguments
- torch 2.0.0 allows us to pass multiple args and kwargs to hooks
- handle multiple inputs and outputs in core.Hook and core.BasicHook, by passing all required grad_outputs and inputs to the backward implementation
- BasicHook still only processes a single input
- Hook checks the function signature to allow backwards-compatibility
- added a basic test that uses the kwargs-signature
- added a note that keyword arguments are supported in the documentation

Notes:
- This stands in conflict with #168, but promises a better implementation by handling inputs and outputs as common to a single function, rather than individually as proposed in #168
- This does not deal with parameter gradients, which are better left to a separate PR

implements #176
1 parent 4b73ea4 commit c4fbdaf

File tree: 5 files changed, +224 -60 lines

docs/source/how-to/write-custom-rules.rst

Lines changed: 12 additions & 4 deletions

@@ -85,15 +85,23 @@ for the layer-wise relevance propagation (LRP)-based **Composites**, used for
 all activations.
 :py:class:`~zennit.core.Hook` has a dictionary attribute ``stored_tensors``,
 which is used to store the output gradient as ``stored_tensors['grad_output']``.
-:py:meth:`~zennit.core.Hook.forward` has 3 arguments:
+:py:meth:`~zennit.core.Hook.forward` can have one of two signatures, one with keyword arguments and one without.
+If the rule does not need to handle keyword arguments:
 
 * ``module``, which is the current module the hook has been registered to,
-* ``input``, which is the module's input tensor, and
-* ``output``, which is the module's output tensor.
+* ``input``, which are the module's input tensors, and
+* ``output``, which are the module's output tensors.
+
+If the rule should also handle keyword arguments (new in version 1.0.0), the following signature may be used:
+
+* ``module``, which is the current module the hook has been registered to,
+* ``args``, which are the module's positional inputs (mixed tensors and parameters allowed),
+* ``kwargs``, which are the module's keyword inputs (tensors unsupported), and
+* ``output``, which are the module's output tensors.
 
 :py:meth:`~zennit.core.Hook.forward` is always called *after* the forward has
 been called, thus making ``output`` available.
-Using the notation above, ``input`` is :math:`x` and ``output`` is :math:`f(x)`.
+Using the first notation above, ``input`` is :math:`x` and ``output`` is :math:`f(x)`.
 
 A layer-wise *gradient times input* can be implemented by storing the input
 tensor in the forward pass and directly using ``grad_input`` in the backward
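
A minimal sketch of a rule written against the keyword-argument signature described above; the class name is hypothetical and follows the gradient-times-input pattern used in this guide, assuming the ``Hook`` interface from ``zennit.core``:

import torch
from zennit.core import Hook


class GradTimesInputKwargs(Hook):
    '''Hypothetical sketch: layer-wise gradient times input with the keyword-argument signature.'''

    def forward(self, module, args, kwargs, output):
        # store the positional input tensors; tensors passed as keyword arguments are unsupported
        self.stored_tensors['input'] = args

    def backward(self, module, grad_input, grad_output):
        # multiply the stored first input element-wise with the incoming output gradient
        return (self.stored_tensors['input'][0] * grad_output[0],)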

src/zennit/core.py

Lines changed: 156 additions & 46 deletions

@@ -20,6 +20,8 @@
 import weakref
 from contextlib import contextmanager
 from typing import Generator, Iterator
+from itertools import compress, repeat
+from inspect import signature
 
 import torch
 
@@ -235,6 +237,44 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper
 
 
+def uncompress(data, selector, compressed) -> Generator:
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    for select in its:
+        try:
+            if select:
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+        except StopIteration:
+            break
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
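
A brief usage sketch of the new ``uncompress`` helper (the values are only illustrative): it undoes ``itertools.compress`` by yielding replacement values at selected positions and the original data elsewhere.

from itertools import compress

from zennit.core import uncompress  # module-level helper added in this commit

data = ['a', 1, 'b', 2]                      # mixed positional arguments
selector = [False, True, False, True]        # True marks e.g. the tensor positions
compressed = list(compress(data, selector))  # [1, 2]

# re-insert processed values at the selected positions, keep the rest from `data`
restored = list(uncompress(data, selector, [10, 20]))
assert restored == ['a', 10, 'b', 20]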
@@ -394,6 +434,7 @@ def forward(ctx, *inputs):
         inputs: tuple of :py:obj:`torch.Tensor`
             The unmodified inputs.
         '''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs
 
     @staticmethod
@@ -422,15 +463,41 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()
 
-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            # sanity check, should never happen because the check above already catches cases in which no input
+            # tensor requires a gradient, and in normal conditions, we will always obtain a grad_fn from `Identity`
+            # for each tensor with requires_grad=True
+            raise RuntimeError('Backward hook could not be registered!')  # pragma: no cover
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
 
         Returns
         -------
@@ -440,40 +507,41 @@ def pre_forward(self, module, input):
         '''
         hook_ref = weakref.ref(self)
 
+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return None
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None
 
-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
-
-    def post_forward(self, module, input, output):
+        return post_args, kwargs
+
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
@@ -484,23 +552,35 @@ def post_forward(self, module, input, output):
         '''
         hook_ref = weakref.ref(self)
 
+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return None
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output
 
     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook.
@@ -516,15 +596,17 @@ def pre_backward(self, module, grad_input, grad_output):
         '''
         self.stored_tensors['grad_output'] = grad_output
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
@@ -573,11 +655,34 @@ def register(self, module):
             A list of removable handles, one for each registered hook.
 
         '''
+        def with_kwargs(method, has_output=True):
+            '''Check whether the method uses args/kwargs, or only inputs. This ensures compatibility with rules
+            that do not consider kwargs, and reduces code clutter.
+
+            Parameters
+            ----------
+            method: function
+                Function to check.
+            has_output: bool
+                Whether `method` expects the module output as its final argument.
+
+            Returns
+            -------
+            bool
+                True if `method` uses kwargs.
+            '''
+            params = signature(method).parameters
+            # assume with_kwargs if forward has not 3 parameters and 3rd is not called 'output'
+            if has_output:
+                return len(params) != 3 and list(params)[2] != 'output'
+            # e.g., pre_forward has no output, so we expect 2 parameters
+            return len(params) != 2
+
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=with_kwargs(self.pre_forward, False)),
+            module.register_forward_hook(self.post_forward, with_kwargs=with_kwargs(self.post_forward)),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs(self.forward)),
         ])
 
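A small illustration of the signature check performed by the ``with_kwargs`` helper above (the function names are hypothetical): a hook method is treated as keyword-argument aware unless it has exactly three parameters with the third named ``output``.

from inspect import signature


def forward_plain(module, input, output):
    '''Legacy signature without keyword arguments.'''


def forward_kwargs(module, args, kwargs, output):
    '''New signature with keyword arguments.'''


def uses_kwargs(method):
    # mirrors with_kwargs(method, has_output=True) from Hook.register
    params = signature(method).parameters
    return len(params) != 3 and list(params)[2] != 'output'


assert not uses_kwargs(forward_plain)   # registered with with_kwargs=False
assert uses_kwargs(forward_kwargs)      # registered with with_kwargs=True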

@@ -645,19 +750,22 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs
 
     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.
@@ -676,13 +784,15 @@ def backward(self, module, grad_input, grad_output):
         tuple of :py:obj:`torch.nn.Module`
             The modified input gradient tensors.
         '''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_input, *original_args = self.stored_tensors['input']
+        original_input = original_input.clone()
+        original_kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
             input = in_mod(original_input).requires_grad_()
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
+                output = modified.forward(input, *original_args, **original_kwargs)
                 output = out_mod(output)
             inputs.append(input)
             outputs.append(output)
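
To sketch the effect on ``BasicHook``-based rules: the stored keyword arguments are now replayed when the modified forward pass is recomputed. The module below is a hypothetical example (not part of this commit) whose forward takes a non-tensor keyword argument; the existing ``Epsilon`` rule is used only for illustration.

import torch
from zennit.rules import Epsilon


class ScaledLinear(torch.nn.Linear):
    '''Hypothetical module whose forward accepts an extra keyword argument.'''

    def forward(self, input, scale=1.0):
        return super().forward(input) * scale


module = ScaledLinear(4, 4)
hook = Epsilon()
handles = hook.register(module)

input = torch.randn(1, 4, requires_grad=True)
# `scale` is stored by BasicHook.forward and replayed inside BasicHook.backward
output = module(input, scale=2.0)
relevance, = torch.autograd.grad(output, input, torch.ones_like(output))

handles.remove()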

tests/unit/.pytest.ini

Lines changed: 3 additions & 0 deletions

@@ -10,3 +10,6 @@ testpaths = tests/unit
 # with output, (a) all except passed (p/P), or (A) all
 # --showlocals: Show local variables in tracebacks
 addopts = -ra --showlocals
+
+markers =
+    extended: do tests with multiple seeds (deselect with '-m "not extended"')

tests/unit/conftest.py

Lines changed: 2 additions & 3 deletions

@@ -43,9 +43,8 @@ def pytest_generate_tests(metafunc):
         scope='session',
         params=[
             0xdeadbeef,
-            0xd0c0ffee,
             *[pytest.param(seed, marks=pytest.mark.extended) for seed in [
-                0xc001bee5, 0xc01dfee7, 0xbe577001, 0xca7b0075, 0x1057b0a7, 0x900ddeed
+                0xd0c0ffee, 0xc001bee5, 0xc01dfee7, 0xbe577001, 0xca7b0075, 0x1057b0a7, 0x900ddeed
             ]],
         ],
         ids=hex
@@ -261,7 +260,7 @@ def partial_name_map_composite(name_map_composite, pyrng):
 
 @pytest.fixture(scope='session')
 def mixed_composite(partial_name_map_composite, special_first_layer_map_composite):
-    '''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and
+    '''Fixture to create mixtures of explicit NameMapComposite and
     SpecialFirstLayerMapComposites.
     '''
     composites = [partial_name_map_composite, special_first_layer_map_composite]
