
Commit 1e52ee7

Core: Multiple Inputs and Keyword Arguments
- use the with_kwargs additions to forward hooks in torch 2.0.0 to pass keyword arguments
- handle multiple inputs and outputs in core.Hook and core.BasicHook, by passing all required grad_outputs and inputs to the backward implementation

TODO:
- attribution scores are currently wrong in BasicHook, likely an issue with the gradient computation inside BasicHook; there may be cross-terms interacting that should not interact
- finish draft and test implementation
- add tests
- add documentation
- this stands in conflict with #168, but promises a better implementation by handling inputs and outputs as common to a single function, rather than individually as proposed in #168
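For context, the torch 2.0.0 additions referred to above are the `with_kwargs` flags of `register_forward_pre_hook` and `register_forward_hook`, which hand the module's keyword arguments to the hooks. A minimal sketch of those hook signatures, independent of zennit (the module, shapes and print statements below are illustrative only):

import torch

def pre_hook(module, args, kwargs):
    # with_kwargs=True: receives positional and keyword arguments; may return
    # a replacement (args, kwargs) tuple instead of None
    print('pre-forward:', len(args), 'positional args, keywords:', list(kwargs))

def post_hook(module, args, kwargs, output):
    # with_kwargs=True: same inputs plus the module output; may return a new output
    print('post-forward output shape:', output.shape)

module = torch.nn.Bilinear(4, 5, 3)   # a stock module taking two positional inputs
module.register_forward_pre_hook(pre_hook, with_kwargs=True)
module.register_forward_hook(post_hook, with_kwargs=True)
module(torch.randn(2, 4), torch.randn(2, 5))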
1 parent 1cdc001 commit 1e52ee7

1 file changed: +152 −47

src/zennit/core.py

Lines changed: 152 additions & 47 deletions
@@ -20,6 +20,8 @@
 import weakref
 from contextlib import contextmanager
 from typing import Generator, Iterator
+from itertools import compress, repeat, islice, chain
+from inspect import signature
 
 import torch
 
@@ -235,6 +237,43 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper
 
 
+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        for choice in selector:
+            if choice:
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
 
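To illustrate the contract of the new helper, a small round-trip with `itertools.compress`; the values are arbitrary, and the import assumes this branch of zennit is installed:

from itertools import compress
from zennit.core import uncompress   # the helper added above

data = ['a', 1, 'b', 2]
selector = [isinstance(elem, str) for elem in data]              # [True, False, True, False]
processed = [elem.upper() for elem in compress(data, selector)]  # ['A', 'B']
restored = list(uncompress(data, selector, processed))
print(restored)   # ['A', 1, 'B', 2]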
@@ -394,6 +433,7 @@ def forward(ctx, *inputs):
         inputs: tuple of :py:obj:`torch.Tensor`
             The unmodified inputs.
         '''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs
 
     @staticmethod
@@ -422,7 +462,28 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()
 
-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.
 
         Parameters
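`_inject_grad_fn` relies on the fact that all outputs of a single autograd function share one `grad_fn` node, so a single hook on that node sees the gradients of every tensor argument at once. A rough sketch of that mechanism with a stock operation (not zennit code; shapes and prints are illustrative):

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
stacked = torch.stack((a, b))       # a single grad_fn node shared by both inputs

def hook(grad_inputs, grad_outputs):
    # called once with the gradients of all inputs of the node; returning None
    # keeps them unchanged, returning a tuple would replace them
    print('gradients for', len(grad_inputs), 'inputs')

handle = stacked.grad_fn.register_hook(hook)
stacked.sum().backward()
handle.remove()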
@@ -440,32 +501,31 @@ def pre_forward(self, module, input):
         '''
         hook_ref = weakref.ref(self)
 
+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None
 
-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs
 
-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.
 
         Parameters
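The wrapper above maps the gradients back to the positional-argument layout: non-tensor arguments receive None, which `uncompress` pulls from an endless `repeat(None)`. Sketched in isolation (the mask and gradient stand-ins are made up; the import again assumes this branch is installed):

from itertools import repeat
from zennit.core import uncompress

tensor_mask = (True, False, True)        # which positional args were tensors
grad_input = ('grad_arg0', 'grad_arg2')  # stand-ins for the two tensor gradients
padded = list(uncompress(repeat(None), tensor_mask, grad_input))
print(padded)   # ['grad_arg0', None, 'grad_arg2']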
@@ -484,23 +544,35 @@ def post_forward(self, module, input, output):
         '''
         hook_ref = weakref.ref(self)
 
+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output
 
     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook.
@@ -516,7 +588,7 @@ def pre_backward(self, module, grad_input, grad_output):
         '''
         self.stored_tensors['grad_output'] = grad_output
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass.
 
         Parameters
@@ -573,11 +645,14 @@ def register(self, module):
             A list of removable handles, one for each registered hook.
 
         '''
+        # assume with_kwargs if forward does not have 3 parameters and the 3rd is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])
 
 
@@ -645,7 +720,7 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.
 
         Parameters
@@ -657,7 +732,8 @@ def forward(self, module, input, output):
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs
 
     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.
@@ -676,25 +752,54 @@ def backward(self, module, grad_input, grad_output):
         tuple of :py:obj:`torch.nn.Module`
             The modified input gradient tensors.
         '''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
             outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # if there is only a single output modifier, grad_outputs may contain tensors
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
             grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
         )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))
 
     def copy(self):
         '''Return a copy of this hook.
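The gradient bookkeeping at the end of `backward` follows a flatten-then-regroup pattern: the nested per-modifier tensors are flattened for a single `torch.autograd.grad` call and then split back according to the recorded group sizes. The same pattern on plain placeholder values:

from itertools import chain, islice

groups = [('a1', 'a2'), ('b1',), ('c1', 'c2', 'c3')]
sizes = [len(group) for group in groups]      # analogous to input_struct
flat = tuple(chain.from_iterable(groups))     # a single flat argument list
flat_it = iter(flat)
regrouped = [tuple(islice(flat_it, size)) for size in sizes]
assert regrouped == groups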
