 import weakref
 from contextlib import contextmanager
 from typing import Generator, Iterator
+from itertools import compress, repeat, islice, chain
+from inspect import signature

 import torch

@@ -235,6 +237,44 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper


+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        An iterable similar to the original data. `False` values in `selector` will be filled with values from this
+        iterable, while `True` values will cause this iterable to be skipped one ahead.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The result of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        while True:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

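For intuition, here is a minimal usage sketch of the new `uncompress` helper (illustrative values only): it re-interleaves the result of `itertools.compress` with the positions that were filtered out.

    from itertools import compress

    data = ['a', 'b', 'c', 'd']
    selector = [True, False, True, False]
    packed = list(compress(data, selector))            # ['a', 'c']
    # padding back: True positions come from `packed`, False positions from `data`
    restored = list(uncompress(data, selector, packed))
    assert restored == ['a', 'b', 'c', 'd']
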
@@ -394,6 +433,7 @@ def forward(ctx, *inputs):
         inputs: tuple of :py:obj:`torch.Tensor`
             The unmodified inputs.
         '''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs

     @staticmethod
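The added `mark_non_differentiable` call keeps pass-through outputs from picking up gradients they should not have. A rough standalone sketch of the effect, using a hypothetical `MyIdentity` analogous to `Identity` above:

    import torch

    class MyIdentity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *inputs):
            # without this call, every returned tensor would report requires_grad=True
            # as soon as any single input requires a gradient
            ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
            return inputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            return grad_outputs

    a = torch.ones(2, requires_grad=True)
    b = torch.ones(2)  # should stay gradient-free
    out_a, out_b = MyIdentity.apply(a, b)
    assert out_a.requires_grad and not out_b.requires_grad
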
@@ -422,7 +462,28 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()

-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.

         Parameters
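A rough sketch of what the new `_inject_grad_fn` helper is expected to return for a mixed argument tuple (hypothetical values; the import path assumes this module is zennit's `zennit.core` with the change above applied):

    import torch
    from zennit.core import Hook  # assumed import path for this module

    x = torch.randn(1, 3, requires_grad=True)
    grad_fn, post_args, mask = Hook._inject_grad_fn((x, 'flag', 4))
    # only tensor positions go through Identity; non-tensor arguments pass through untouched
    assert mask == (True, False, False)
    assert post_args[1:] == ('flag', 4)
    # a single shared grad_fn on which pre_forward/post_forward register their backward hooks
    assert grad_fn is not None
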
@@ -440,32 +501,31 @@ def pre_forward(self, module, input):
         '''
         hook_ref = weakref.ref(self)

+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None

-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs

-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.

         Parameters
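In the reworked `pre_forward`, the wrapper receives gradients only for the tensor arguments and pads the remaining positions with `None` before handing them to the hook's `backward`. The padding idiom in isolation (illustrative stand-in values):

    from itertools import repeat

    input_tensor_mask = (True, False, True)            # e.g. (tensor, int, tensor)
    grad_input = ('grad_for_arg0', 'grad_for_arg2')    # stand-ins for gradient tensors
    padded = list(uncompress(repeat(None), input_tensor_mask, grad_input))
    assert padded == ['grad_for_arg0', None, 'grad_for_arg2']
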
@@ -484,23 +544,35 @@ def post_forward(self, module, input, output):
         '''
         hook_ref = weakref.ref(self)

+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output

     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook.
@@ -516,7 +588,7 @@ def pre_backward(self, module, grad_input, grad_output):
         '''
         self.stored_tensors['grad_output'] = grad_output

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass.

         Parameters
@@ -573,11 +645,14 @@ def register(self, module):
             A list of removable handles, one for each registered hook.

         '''
+        # assume with_kwargs if forward does not have 3 parameters and its 3rd parameter is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])


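To illustrate the `with_kwargs` heuristic above (hypothetical hook classes, and a helper name made up for this sketch): a `forward` declared as `(module, input, output)` keeps the old calling convention, while `(module, args, kwargs, output)` opts into `with_kwargs=True`.

    from inspect import signature

    class OldStyle:
        def forward(self, module, input, output):
            pass

    class NewStyle:
        def forward(self, module, args, kwargs, output):
            pass

    def wants_kwargs(hook):
        # mirrors the check in register(): bound methods hide `self`
        params = signature(hook.forward).parameters
        return len(params) != 3 and list(params)[2] != 'output'

    assert wants_kwargs(OldStyle()) is False
    assert wants_kwargs(NewStyle()) is True
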
@@ -645,7 +720,7 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.

         Parameters
@@ -657,7 +732,8 @@ def forward(self, module, input, output):
         output: :py:obj:`torch.Tensor`
             The output tensor.
         '''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs

     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.
@@ -676,25 +752,54 @@ def backward(self, module, grad_input, grad_output):
         tuple of :py:obj:`torch.nn.Module`
             The modified input gradient tensors.
         '''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
             outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # if there is only a single output modifier, grad_outputs may contain tensors
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
             grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
         )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))

     def copy(self):
         '''Return a copy of this hook.
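The reworked `backward` runs one modified forward pass per modifier triple, transposes the collected inputs and outputs to a per-position layout, flattens them for a single `torch.autograd.grad` call, and then re-chunks the flat gradients per input position for the reducer. The transpose/flatten/regroup round trip in isolation (illustrative strings instead of tensors):

    from itertools import chain, islice

    # one entry per modifier pass, each holding that pass' (compressed) inputs
    per_pass = [('x_mod1', 'y_mod1'), ('x_mod2', 'y_mod2'), ('x_mod3', 'y_mod3')]

    # transpose: one entry per input position, holding that input across all passes
    per_position = list(zip(*per_pass))
    struct = [len(elem) for elem in per_position]       # [3, 3]

    # flatten for a single torch.autograd.grad call, then re-chunk the flat results
    flat = tuple(chain.from_iterable(per_position))
    flat_it = iter(flat)
    regrouped = [tuple(islice(flat_it, size)) for size in struct]
    assert regrouped == per_position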