#!/usr/bin/env python3

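# Compatibility tests: run Captum's perturbation-based LLM attribution methods
# end to end against a Hugging Face `transformers` causal language model.
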
from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
    # pyre-fixme[21]: Could not find a module corresponding to import `transformers`
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    HAS_HF = False


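# Run the test class once per (device, use_cached_outputs) combination; the
# CUDA variants are included only when a GPU is available.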
@parameterized_class(
    ("device", "use_cached_outputs"),
    (
        [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
        if torch.cuda.is_available()
        else [("cpu", True), ("cpu", False)]
    ),
)
class TestLLMAttrHFCompatibility(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
    use_cached_outputs: bool

    def setUp(self) -> None:
        if not HAS_HF:
            self.skipTest("transformers package not found, skipping tests")
        super().setUp()

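    # One parameterized case per perturbation-based attribution method; only
    # ShapleyValueSampling takes an explicit n_samples argument here.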
    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension
    @parameterized.expand(
        [
            (
                AttrClass,
                n_samples,
            )
            for AttrClass, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr_hf_compatibility(
        self,
        AttrClass: Type[PerturbationAttribution],
        n_samples: Optional[int],
    ) -> None:
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

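        # A tiny randomly initialized Llama checkpoint keeps the test fast and
        # avoids depending on real pretrained weights.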
        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

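        # Eval mode disables training-time behavior (e.g. dropout) so the
        # perturbation runs are deterministic.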
        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

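        # The template exposes four interpretable features ("a", "c", "d",
        # "f"); attribution is computed against the target string "m n o p q".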
        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
            # for 4th positional argument, expected
            # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
            **attr_kws,  # type: ignore
        )
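        # Expect one sequence-level attribution score per template feature,
        # returned on the same device as the model.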
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)