Commit 19839e3

Core: Multiple Inputs and Keyword Arguments
- torch 2.0.0 allows us to pass multiple args and kwargs to hooks
- handle multiple inputs and outputs in core.Hook and core.BasicHook, by passing all required grad_outputs and inputs to the backward implementation
- BasicHook still only processes a single input
- Hook checks the function signature to allow backwards compatibility

TODO:
- add tests
- add documentation
- This stands in conflict with #168, but promises a better implementation by handling inputs and outputs as common to a single function, rather than individually as proposed in #168
- This does not deal with parameter gradients, which are better left to a separate PR
- This will implement #176
1 parent 1cdc001 commit 19839e3
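For context, a minimal standalone sketch (not part of the commit) of the torch 2.0 forward-hook signatures this change builds on: registering a forward hook with with_kwargs=True makes PyTorch pass the call's keyword arguments to the hook as an extra dict, which is what lets Hook see both args and kwargs.

import torch

conv = torch.nn.Conv2d(3, 8, 3)

def classic_hook(module, args, output):
    # classic signature: args is the tuple of positional inputs
    print('classic:', len(args), tuple(output.shape))

def kwargs_hook(module, args, kwargs, output):
    # torch >= 2.0 with with_kwargs=True: the keyword arguments arrive as a dict
    print('with kwargs:', len(args), kwargs, tuple(output.shape))

conv.register_forward_hook(classic_hook)
conv.register_forward_hook(kwargs_hook, with_kwargs=True)
conv(torch.randn(1, 3, 8, 8))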

src/zennit/core.py

Lines changed: 115 additions & 37 deletions
@@ -20,6 +20,8 @@
 import weakref
 from contextlib import contextmanager
 from typing import Generator, Iterator
+from itertools import compress, repeat, islice, chain
+from inspect import signature

 import torch

@@ -235,6 +237,43 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper


+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        while True:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
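A quick usage sketch (not part of the diff, with the helper above in scope) of how uncompress round-trips with itertools.compress; further down it is used both to splice the post-Identity tensors back into the full argument tuple and to pad grad_input with None at non-tensor positions.

from itertools import compress, repeat

args = ('a', 1, 'b', 2)
selector = [False, True, False, True]

picked = list(compress(args, selector))             # [1, 2]
changed = [value * 10 for value in picked]          # [10, 20]

# splice the changed values back in at the positions marked True
print(list(uncompress(args, selector, changed)))            # ['a', 10, 'b', 20]
# pad with None at the positions marked False; the finite selector bounds repeat(None)
print(list(uncompress(repeat(None), selector, changed)))    # [None, 10, None, 20]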
@@ -394,6 +433,7 @@ def forward(ctx, *inputs):
         inputs: tuple of :py:obj:`torch.Tensor`
             The unmodified inputs.
         '''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs

     @staticmethod
@@ -422,7 +462,28 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()

-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.

         Parameters
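As an aside (my check, not part of the commit), the reason a single grad_fn can serve every tensor argument is that one application of zennit's Identity produces outputs sharing one backward node, assuming zennit.core.Identity behaves like any multi-output custom autograd Function:

import torch
from zennit.core import Identity

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

post_x, post_y = Identity.apply(x, y)
# both outputs point to the same backward node, so a single registered hook
# receives the gradients of all tensor arguments at once
assert post_x.grad_fn is post_y.grad_fn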
@@ -440,32 +501,31 @@ def pre_forward(self, module, input):
         '''
         hook_ref = weakref.ref(self)

+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None

-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs

-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.

         Parameters
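The wrapper above registers a hook directly on an autograd node; a standalone sketch in plain PyTorch (torch.stack stands in for Identity here, not zennit code) of that interface, where the hook receives one gradient per node input and may return replacements:

import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

# a single node whose backward covers both tensors, similar to Identity.apply(*tensors)
out = torch.stack((x, y))

def node_hook(grad_inputs, grad_outputs):
    # one gradient per input of the node; Hook.backward receives these,
    # with None filled in at the positions of non-tensor arguments
    print(len(grad_inputs), len(grad_outputs))
    return grad_inputs  # returning a tuple replaces the gradients flowing to the inputs

out.grad_fn.register_hook(node_hook)
out.sum().backward()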
@@ -484,23 +544,35 @@ def post_forward(self, module, input, output):
         '''
         hook_ref = weakref.ref(self)

+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output,
+                    )),
+                )
             return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output

     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook.
@@ -516,7 +588,7 @@ def pre_backward(self, module, grad_input, grad_output):
         '''
         self.stored_tensors['grad_output'] = grad_output

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass.

         Parameters
@@ -573,11 +645,14 @@ def register(self, module):
             A list of removable handles, one for each registered hook.

         '''
+        # assume with_kwargs if forward does not have 3 parameters and the 3rd is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])


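A standalone sketch (hypothetical hook functions, not from the commit) of what this signature check decides; bound methods drop self, so a classic forward hook has exactly three parameters with 'output' third:

from inspect import signature

def forward_old(module, input, output):           # classic three-parameter hook
    pass

def forward_new(module, args, kwargs, output):    # torch 2.0 style hook with kwargs
    pass

for forward in (forward_old, forward_new):
    params = signature(forward).parameters
    with_kwargs = len(params) != 3 and list(params)[2] != 'output'
    print(forward.__name__, with_kwargs)
# forward_old False  -> registered without kwargs, old subclasses keep working
# forward_new True   -> registered with with_kwargs=True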
@@ -645,7 +720,7 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.

         Parameters
@@ -657,7 +732,8 @@ def forward(self, module, input, output):
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs

     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.
@@ -676,13 +752,15 @@ def backward(self, module, grad_input, grad_output):
         tuple of :py:obj:`torch.nn.Module`
             The modified input gradient tensors.
         '''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_input, *original_args = self.stored_tensors['input']
+        original_input = original_input.clone()
+        original_kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
             input = in_mod(original_input).requires_grad_()
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
+                output = modified.forward(input, *original_args, **original_kwargs)
                 output = out_mod(output)
             inputs.append(input)
             outputs.append(output)
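To illustrate how the stored args and kwargs are re-used when the modified forward passes are replayed, a toy module (hypothetical, not from zennit); only the first positional argument is modified, everything else is passed through unchanged:

import torch

class Scale(torch.nn.Module):
    '''Toy module with one tensor input plus an extra positional and keyword argument.'''
    def forward(self, input, offset, scale=1.0):
        return input * scale + offset

module = Scale()
stored_input = (torch.randn(4), 0.5)    # what BasicHook.forward now stores as 'input'
stored_kwargs = {'scale': 2.0}          # ... and as 'kwargs'

original_input, *original_args = stored_input
input = original_input.clone().requires_grad_()
output = module.forward(input, *original_args, **stored_kwargs)
print(tuple(output.shape))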
