 import weakref
 from contextlib import contextmanager
 from typing import Generator, Iterator
+from itertools import compress, repeat
+from inspect import signature
 
 import torch
 
@@ -235,6 +237,44 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper
 
 
+def uncompress(data, selector, compressed) -> Generator:
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    for select in its:
+        try:
+            if select:
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+        except StopIteration:
+            break
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
 
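For illustration, a small standalone sketch of how `uncompress` round-trips the output of `itertools.compress`, and how passing `repeat(None)` as `data` pads the unselected positions with `None`, as done for the gradient tuples in the hooks further down (hypothetical values, not part of the patch):

    from itertools import compress, repeat

    args = ('first', 0, 'third', 2)
    selector = (False, True, False, True)
    packed = tuple(compress(args, selector))          # (0, 2)

    list(uncompress(args, selector, packed))          # ['first', 0, 'third', 2]
    list(uncompress(repeat(None), selector, packed))  # [None, 0, None, 2]
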
@@ -394,6 +434,7 @@ def forward(ctx, *inputs):
         inputs: tuple of :py:obj:`torch.Tensor`
             The unmodified inputs.
         '''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs
 
     @staticmethod
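A rough standalone sketch of what `mark_non_differentiable` achieves here, using the module's `Identity` (hypothetical tensors, not part of the patch): outputs whose inputs do not require a gradient are treated as constants, while the others stay attached to the graph:

    import torch

    a = torch.zeros(2, requires_grad=True)
    b = torch.zeros(2)                   # requires_grad=False

    out_a, out_b = Identity.apply(a, b)  # Identity as defined in this module
    # out_a requires a gradient and now carries a grad_fn
    # out_b was marked non-differentiable and behaves like a constant
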
@@ -422,15 +463,41 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()
 
-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            # sanity check, should never happen because the check above already catches cases in which no input tensor
+            # requires a gradient, and in normal conditions, we will always obtain a grad_fn from `Identity` for each
+            # tensor with requires_grad=True
+            raise RuntimeError('Backward hook could not be registered!')  # pragma: no cover
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
 
         Returns
         -------
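The helper relies on a single `Identity.apply` call producing one backward-graph node that is shared by all of its outputs, so one `register_hook` on that node intercepts the gradients of every tensor argument at once; a minimal sketch with hypothetical tensors (not part of the patch):

    import torch

    a = torch.ones(2, requires_grad=True)
    b = torch.zeros(3, requires_grad=True)

    out_a, out_b = Identity.apply(a, b)
    assert out_a.grad_fn is out_b.grad_fn      # one shared backward node

    def show(grad_input, grad_output):
        # grad_input/grad_output are tuples with one entry per tensor
        return grad_input                      # returning them unchanged is a no-op

    handle = out_a.grad_fn.register_hook(show)
    (out_a.sum() + out_b.sum()).backward()
    handle.remove()
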
@@ -440,40 +507,41 @@ def pre_forward(self, module, input):
         '''
         hook_ref = weakref.ref(self)
 
+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return None
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None
 
-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
-
-    def post_forward(self, module, input, output):
+        return post_args, kwargs
+
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
 
@@ -484,23 +552,35 @@ def post_forward(self, module, input, output):
         '''
         hook_ref = weakref.ref(self)
 
+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return None
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output
 
     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook.
@@ -516,15 +596,17 @@ def pre_backward(self, module, grad_input, grad_output):
         '''
         self.stored_tensors['grad_output'] = grad_output
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
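The four-argument hook signature matches PyTorch's kwargs-aware forward hooks (available since torch 2.0); a standalone sketch with a plain module, unrelated to the classes in this file:

    import torch

    def forward_hook(module, args, kwargs, output):
        # args is a tuple of positional inputs, kwargs a dict of keyword inputs
        print(len(args), sorted(kwargs))
        return None                               # None keeps the output unchanged

    linear = torch.nn.Linear(4, 2)
    handle = linear.register_forward_hook(forward_hook, with_kwargs=True)
    linear(torch.randn(1, 4))
    handle.remove()
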
@@ -573,11 +655,34 @@ def register(self, module):
             A list of removable handles, one for each registered hook.
 
         '''
+        def with_kwargs(method, has_output=True):
+            '''Check whether the method uses args/kwargs, or only inputs. This ensures compatibility with rules that do
+            not consider kwargs, and reduces code clutter.
+
+            Parameters
+            ----------
+            method: function
+                Function to check.
+            has_output: bool
+                Whether `method` is a full forward hook and thus has an ``output`` parameter, as opposed to a
+                pre-hook, which has none.
+
+            Returns
+            -------
+            bool
+                True if `method` uses kwargs.
+            '''
+            params = signature(method).parameters
+            # assume with_kwargs if the method does not have exactly 3 parameters and its 3rd is not named 'output'
+            if has_output:
+                return len(params) != 3 and list(params)[2] != 'output'
+            # e.g., pre_forward has no output, so the legacy signature has only 2 parameters
+            return len(params) != 2
+
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=with_kwargs(self.pre_forward, False)),
+            module.register_forward_hook(self.post_forward, with_kwargs=with_kwargs(self.post_forward)),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs(self.forward)),
         ])
 
 
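A sketch of the two hook styles the `with_kwargs` check distinguishes, and of how a kwargs-aware pre-hook can replace the inputs by returning an `(args, kwargs)` tuple (hypothetical hook functions, not part of the patch):

    import torch
    from inspect import signature

    def legacy_forward(module, input, output):          # old-style: with_kwargs=False
        return None

    def kwargs_forward(module, args, kwargs, output):   # new-style: with_kwargs=True
        return None

    list(signature(legacy_forward).parameters)   # ['module', 'input', 'output']
    list(signature(kwargs_forward).parameters)   # ['module', 'args', 'kwargs', 'output']

    def pre_hook(module, args, kwargs):
        # returning (args, kwargs) replaces what module.forward receives
        return args, kwargs

    handle = torch.nn.Linear(4, 2).register_forward_pre_hook(pre_hook, with_kwargs=True)
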
@@ -645,19 +750,22 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.
 
         Parameters
         ----------
         module: :py:obj:`torch.nn.Module`
             The module to which this hook is attached.
-        input: :py:obj:`torch.Tensor`
-            The input tensor.
+        args: tuple of :py:obj:`torch.Tensor`
+            The input tensors passed to ``module.forward``.
+        kwargs: dict
+            The keyword arguments passed to ``module.forward``.
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs
 
     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.
@@ -676,13 +784,15 @@ def backward(self, module, grad_input, grad_output):
         tuple of :py:obj:`torch.nn.Module`
             The modified input gradient tensors.
         '''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_input, *original_args = self.stored_tensors['input']
+        original_input = original_input.clone()
+        original_kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
             input = in_mod(original_input).requires_grad_()
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
+                output = modified.forward(input, *original_args, **original_kwargs)
             output = out_mod(output)
             inputs.append(input)
             outputs.append(output)
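The keyword arguments are stored so that the modified forward pass can be replayed with exactly the arguments the module originally received; a minimal standalone sketch of that pattern with a hypothetical module and no modifiers:

    import torch

    class Scale(torch.nn.Module):
        '''Hypothetical module whose forward takes a keyword argument.'''
        def forward(self, input, factor=1.0):
            return input * factor

    module = Scale()
    original_input = torch.ones(3)
    original_kwargs = {'factor': 2.0}

    # replay the forward pass with the stored keyword arguments
    input = original_input.clone().requires_grad_()
    with torch.autograd.enable_grad():
        output = module.forward(input, **original_kwargs)
    grad, = torch.autograd.grad(output, input, grad_outputs=torch.ones_like(output))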