 import functools
 import weakref
 from contextlib import contextmanager
+from itertools import compress, repeat, islice, chain
+from inspect import signature

 import torch

@@ -234,6 +236,43 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper


+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable
+    similar to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from
+    `compressed` or `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a
+    value from `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The original data (or an iterable similar to it). `False` values in `selector` will be filled with values
+        from this iterator, while `True` values will cause this iterable to be skipped one ahead.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The result of :py:obj:`itertools.compress`. Its values will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of
+        `compressed`, skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    while True:
+        try:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+        except StopIteration:
+            return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

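For illustration only (not part of the patch), `uncompress` is meant to round-trip with `itertools.compress`; the values below are made up:

    from itertools import compress

    data = ['a', 'b', 'c', 'd']
    selector = [True, False, True, False]
    squeezed = list(compress(data, selector))          # ['a', 'c']
    modified = [value.upper() for value in squeezed]   # ['A', 'C']
    # re-insert the modified values at the selected positions, keeping the
    # untouched elements of `data` where the selector is False
    restored = list(uncompress(data, selector, modified))
    # restored == ['A', 'b', 'C', 'd']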
@@ -360,6 +399,7 @@ class Identity(torch.autograd.Function):
     @staticmethod
     def forward(ctx, *inputs):
         '''Forward identity.'''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs

     @staticmethod
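A rough sketch (not part of the patch, hedged on PyTorch's custom-Function semantics) of what marking the gradient-free inputs as non-differentiable achieves:

    import torch

    a = torch.ones(3, requires_grad=True)
    b = torch.zeros(3)                    # does not require gradient
    out_a, out_b = Identity.apply(a, b)
    # out_a is attached to the shared Identity grad_fn node
    assert out_a.grad_fn is not None
    # out_b was marked non-differentiable, so it stays gradient-free
    assert not out_b.requires_grad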
@@ -375,62 +415,94 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()

-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.'''
         hook_ref = weakref.ref(self)

+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None

-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs

-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
         hook_ref = weakref.ref(self)

+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output

     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass'''

     def backward(self, module, grad_input, grad_output):
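To make the `uncompress(repeat(None), ...)` padding above concrete, here is a small sketch with made-up arguments (a tensor, a non-tensor option, and another tensor); only the tensor positions receive gradients from the hook, and the non-tensor position is filled with `None`:

    import torch
    from itertools import repeat

    args = (torch.rand(2, requires_grad=True), 'some-option', torch.rand(2, requires_grad=True))
    tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
    # tensor_mask == (True, False, True)

    # the grad_fn hook only receives gradients for the tensor arguments;
    # padding restores the original argument structure with None in between
    grads_for_tensors = (torch.ones(2), torch.ones(2))
    padded = list(uncompress(repeat(None), tensor_mask, grads_for_tensors))
    # padded == [tensor([1., 1.]), None, tensor([1., 1.])]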
@@ -449,11 +521,14 @@ def remove(self):

     def register(self, module):
         '''Register this instance by registering all hooks to the supplied module.'''
+        # assume with_kwargs if forward does not have exactly 3 parameters and its 3rd parameter is not named 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])


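A small, self-contained sketch (function names are made up) of how the detection above resolves for the two supported forward signatures:

    from inspect import signature

    def forward_legacy(self, module, input, output):
        pass

    def forward_kwargs(self, module, args, kwargs, output):
        pass

    # signature() of a bound method excludes 'self'; emulate that for a plain
    # function by skipping its first parameter
    params = list(signature(forward_kwargs).parameters)[1:]
    with_kwargs = len(params) != 3 and params[2] != 'output'
    # with_kwargs is True here; for forward_legacy the same check gives False,
    # since exactly 3 parameters remain and the 3rd is named 'output'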
@@ -517,31 +592,61 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.'''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs

     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-            output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+            if not isinstance(output, tuple):
+                output = (output,)
+            output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
             outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # if there is only a single output modifier, grad_outputs may contain tensors
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
             grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
         )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))

     def copy(self):
         '''Return a copy of this hook.
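A toy illustration (placeholder strings instead of tensors) of the flatten-and-regroup step above: with three modifier passes over two tensor inputs, the flat gradient tuple returned by `torch.autograd.grad` is sliced back into one tuple per input position, matching `inputs` for the reducer:

    from itertools import islice

    input_struct = [3, 3]       # three modifier passes for each of two inputs
    gradients_flat = ['ga1', 'ga2', 'ga3', 'gb1', 'gb2', 'gb3']
    gradient_it = iter(gradients_flat)
    gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
    # gradients == [('ga1', 'ga2', 'ga3'), ('gb1', 'gb2', 'gb3')]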