
Commit 4f8caeb

DianjingLiu authored and facebook-github-bot committed
Add unit test verifying compatibility with huggingface models (#1352)
Summary:
Pull Request resolved: #1352

Our current unit tests for LLM Attribution use mocked models that resemble Hugging Face transformer models (e.g., Llama, Llama 2) but can differ from them in unexpected ways, such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3). To validate coverage and stay compatible with future changes to models, this diff adds tests that load Hugging Face models directly and run them through LLM Attribution, which will help us quickly catch any breaking changes. So far we only test the model type `LlamaForCausalLM`.

Reviewed By: vivekmig

Differential Revision: D62894898

fbshipit-source-id: 910be92cabd5a8c428a89fef3689dfc4110a9417
1 parent 49d8689 · commit 4f8caeb
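For context, the compatibility pattern the new test exercises boils down to the following minimal sketch. This is not part of the commit; it reuses the same tiny checkpoint and the same imports as the test file below:

# Minimal sketch: wrap a real Hugging Face causal LM in Captum's
# LLMAttribution instead of a mocked model.
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny test-only Llama checkpoint, the same one used by the new unit test.
name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)

# Attribute the target string to the four "{}" features in the template.
inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_attr.attribute(inp, "m n o p q")
print(res.seq_attr.shape)  # torch.Size([4]) -- one score per feature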

File tree

3 files changed: +92 −0 lines changed

scripts/install_via_conda.sh

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ fi
 # install other deps
 conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2
 conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress
+conda install -q -y transformers

 # install captum
 python setup.py develop

scripts/install_via_pip.sh

Lines changed: 2 additions & 0 deletions

@@ -65,3 +65,5 @@ fi
 if [[ $DEPLOY == true ]]; then
     pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off
 fi
+
+pip install transformers --progress-bar off
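With transformers added to both the conda and pip CI install paths, the new test can import Hugging Face models in either environment; if the package is absent, the test below skips itself via its HAS_HF guard.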
New test file

Lines changed: 89 additions & 0 deletions

#!/usr/bin/env python3


from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
    # pyre-fixme[21]: Could not find a module corresponding to import `transformers`
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    HAS_HF = False


@parameterized_class(
    ("device", "use_cached_outputs"),
    (
        [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
        if torch.cuda.is_available()
        else [("cpu", True), ("cpu", False)]
    ),
)
class TestLLMAttrHFCompatibility(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
    use_cached_outputs: bool

    def setUp(self) -> None:
        if not HAS_HF:
            self.skipTest("transformers package not found, skipping tests")
        super().setUp()

    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension`
    @parameterized.expand(
        [
            (
                AttrClass,
                n_samples,
            )
            for AttrClass, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr_hf_compatibility(
        self,
        AttrClass: Type[PerturbationAttribution],
        n_samples: Optional[int],
    ) -> None:
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
            # for 4th positional argument, expected
            # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
            **attr_kws,  # type: ignore
        )
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
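Note on the test design: the tiny hf-internal-testing/tiny-random-LlamaForCausalLM checkpoint keeps downloads and forward passes small, and parameterized_class multiplies each case across CPU/CUDA devices and cached/uncached output modes, so a single parameterized method covers FeatureAblation, ShapleyValueSampling (with n_samples=1000), and ShapleyValues against real Hugging Face model code.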
