Description
Hi,
I am attempting to implement Conservative LRP rules as part of #184, since I need them for a university project. I am running into issues with the classifier head and was hoping you could point out what I might be doing wrong.
So far I have implemented a composite in the following manner:
```python
from torch.nn import LayerNorm, MultiheadAttention
from zennit.composites import LayerMapComposite, register_composite
from zennit.rules import Epsilon, Norm, Pass, ZPlus
from zennit.types import Activation, AvgPool, Convolution, Linear

# AHConservative and LNConservative are my custom rules (see below)

@register_composite('transformer')
class Transformer(LayerMapComposite):
    def __init__(self, linear_layer_epsilon=1e-6, layer_norm_epsilon=1e-5,
                 layer_map=None, canonizers=None, zero_params=None):
        if layer_map is None:
            layer_map = []
        rule_kwargs = {'zero_params': zero_params}
        # append the default rules after any user-supplied layer map
        layer_map = layer_map + [
            (Activation, Pass()),
            (Convolution, ZPlus()),
            (AvgPool, Norm()),
            (MultiheadAttention, AHConservative()),
            (LayerNorm, LNConservative(epsilon=layer_norm_epsilon)),
            (Linear, Epsilon(epsilon=linear_layer_epsilon, **rule_kwargs)),
        ]
        # layer_map.composites += general_layer_map
        # named_map = NameLayerMapComposite([(("classifier",), )])
        super().__init__(layer_map=layer_map, canonizers=canonizers)
```
Here, AHConservative and LNConservative are the rules described in the CLRP paper.
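Both custom rules follow zennit's BasicHook pattern, analogous to Epsilon. As a rough idea of their shape (a simplified sketch, not my exact implementation; the gradient_mapper and reducer here are just the Epsilon-style placeholders):

```python
from zennit.core import BasicHook, Stabilizer

class LNConservative(BasicHook):
    '''Simplified sketch of a conservative LayerNorm rule (not the real thing).'''
    def __init__(self, epsilon=1e-5):
        stabilizer_fn = Stabilizer.ensure(epsilon)
        super().__init__(
            # redistribute relevance proportionally to the stabilized output
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])),
            # relevance is the modified input times the mapped gradient
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )
```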
I have also implemented a custom attributor that computes relevance scores with respect to the embeddings, since the integer token inputs are not differentiable (a rough sketch of that idea is at the end of this issue). However, the problem I am facing also appears with the basic Gradient attributor. Here is a minimal example:
```python
import torch
from torch import nn

from zennit.attribution import Gradient
from zennit.composites import Transformer

composite = Transformer()
model = nn.Linear(768, 2, bias=True)

with Gradient(model=model, composite=composite) as attributor:
    input = torch.randn((1, 768))
    out, relevance = attributor(input.float(), attr_output=torch.ones((1, 2)))
```
When working with BertForSequenceClassification from the Huggingface transformers library, the classifier head is an nn.Linear module of shape (768, 2), just like the model above. However, I am getting an error from the Epsilon rule, specifically from its gradient_mapper:
```
Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: wrapper)
The size of tensor a (768) must match the size of tensor b (2) at non-singleton dimension 1
  File "/home/chris/zennit/src/zennit/rules.py", line 120, in <lambda>
    gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])),
  File "/home/chris/zennit/src/zennit/core.py", line 539, in backward
    grad_outputs = self.gradient_mapper(grad_output[0], outputs)
  File "/home/chris/zennit/src/zennit/core.py", line 388, in wrapper (Current frame)
    return hook.backward(module, grad_input, grad_output)
  File "/home/chris/zennit/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/chris/zennit/src/zennit/attribution.py", line 257, in grad
    gradient, = torch.autograd.grad(
  File "/home/chris/zennit/src/zennit/attribution.py", line 287, in forward
    return self.grad(input, attr_output_fn)
  File "/home/chris/zennit/src/zennit/attribution.py", line 181, in __call__
    return self.forward(input, attr_output_fn)
  File "/home/chris/zennit/t3.py", line 25, in <module>
    out, relevance = attributor(input.float(), attr_output=torch.ones((1, 2)))
  File "/home/chris/miniconda3/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/chris/miniconda3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
RuntimeError: The size of tensor a (768) must match the size of tensor b (2) at non-singleton dimension 1
```
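To make the mismatch concrete, here is the shape arithmetic I would expect for the classifier head, based on my reading of the Epsilon rule's gradient_mapper (the stabilizer is approximated by a plain epsilon addition here):

```python
import torch
from torch import nn

# shape check for the classifier head in isolation (my understanding)
linear = nn.Linear(768, 2)
x = torch.randn(1, 768)

out = linear(x)              # (1, 2) -- what outputs[0] should hold
out_grad = torch.ones(1, 2)  # what attr_output supplies

# the division in the gradient_mapper should broadcast without issue
print((out_grad / (out + 1e-6)).shape)  # torch.Size([1, 2])

# The RuntimeError instead pairs a size-768 tensor with a size-2 tensor at
# dimension 1, as if an input-sized tensor were reaching the gradient_mapper.
```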
I do understand that the tensor sizes must agree for the division in the gradient_mapper. I therefore suspect I am mishandling the classifier head, but I am not sure how to proceed. Should I upsample the output to the size of the input and just use Epsilon? Should I implement a custom rule? Any help would be greatly appreciated! I'd love to get a rough implementation of the changes suggested in #184 working so we can push the needle on XAI for Transformers a bit.
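For completeness, the embeddings-based attribution I mentioned above boils down to the following. This is a rough sketch rather than my actual attributor: the FromEmbeddings wrapper and the dummy token IDs are illustrative, and it simply reuses zennit's Gradient on a module whose input is the embedding tensor:

```python
import torch
from torch import nn
from zennit.attribution import Gradient

class FromEmbeddings(nn.Module):
    '''Wrap a Huggingface model so the differentiable embeddings are the input.'''
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs_embeds):
        return self.model(inputs_embeds=inputs_embeds).logits

# `model` is a BertForSequenceClassification; the token IDs are dummies
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
embeds = model.get_input_embeddings()(input_ids).detach()

with Gradient(model=FromEmbeddings(model), composite=composite) as attributor:
    # attribute the logit of class 1; relevance has the shape of the embeddings
    out, relevance = attributor(embeds, attr_output=torch.eye(2)[[1]])
```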