Skip to content

Commit 15738d0

Browse files
csauperfacebook-github-bot
authored andcommitted
Fix more pyre errors for llm_attr and tests [2/n] (#1359)
Summary: Pull Request resolved: #1359 Get rid of some more pyre errors by adding better typing Reviewed By: uberblah Differential Revision: D63365945 fbshipit-source-id: 15b221d93e7f701d2a2bac9cc3d9ea2be937e850
1 parent 0af8bd6 commit 15738d0

File tree

3 files changed

+61
-74
lines changed

3 files changed

+61
-74
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def plot_token_attr(
7979
"token_attr is None (no token-level attribution was performed), please "
8080
"use plot_seq_attr instead for the sequence-level attribution plot"
8181
)
82-
token_attr = self.token_attr.cpu() # type: ignore
82+
token_attr = self.token_attr.cpu()
8383

8484
# maximum absolute attribution value
8585
# used as the boundary of normalization
@@ -343,7 +343,7 @@ def _forward_func(
343343
caching=use_cached_outputs,
344344
)
345345

346-
log_prob_list = []
346+
log_prob_list: List[Tensor] = []
347347
outputs = None
348348
for target_token in target_tokens:
349349
if use_cached_outputs:
@@ -382,17 +382,15 @@ def _forward_func(
382382
(model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
383383
)
384384

385-
# pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
386-
# Tensor
387-
total_log_prob: Tensor = sum(log_prob_list) # type: ignore
385+
total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
388386
# 1st element is the total prob, rest are the target tokens
389387
# add a leading dim for batch even we only support single instance for now
390388
if self.include_per_token_attr:
391389
target_log_probs = torch.stack(
392-
[total_log_prob, *log_prob_list], dim=0 # type: ignore
390+
[total_log_prob, *log_prob_list], dim=0
393391
).unsqueeze(0)
394392
else:
395-
target_log_probs = total_log_prob # type: ignore
393+
target_log_probs = total_log_prob
396394
target_probs = torch.exp(target_log_probs)
397395

398396
if _inspect_forward:
@@ -412,9 +410,9 @@ def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
412410
"""
413411
# return tensor(1, n_tokens)
414412
if isinstance(model_input, str):
415-
return self.tokenizer.encode( # type: ignore
416-
model_input, return_tensors="pt"
417-
).to(self.device)
413+
return self.tokenizer.encode(model_input, return_tensors="pt").to(
414+
self.device
415+
)
418416
return model_input.to(self.device)
419417

420418
def attribute(
@@ -544,8 +542,7 @@ def attribute(
544542
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
545543
)
546544

547-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
548-
def attribute_future(self) -> Callable:
545+
def attribute_future(self) -> Callable[[], LLMAttributionResult]:
549546
r"""
550547
This method is not implemented for LLMAttribution.
551548
"""
@@ -612,9 +609,9 @@ def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
612609
Convert str to tokenized tensor
613610
"""
614611
if isinstance(model_input, str):
615-
return self.tokenizer.encode( # type: ignore
616-
model_input, return_tensors="pt"
617-
).to(self.device)
612+
return self.tokenizer.encode(model_input, return_tensors="pt").to(
613+
self.device
614+
)
618615
return model_input.to(self.device)
619616

620617
def attribute(
@@ -745,8 +742,7 @@ def attribute(
745742
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
746743
)
747744

748-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
749-
def attribute_future(self) -> Callable:
745+
def attribute_future(self) -> Callable[[], LLMAttributionResult]:
750746
r"""
751747
This method is not implemented for LLMGradientAttribution.
752748
"""

tests/attr/test_llm_attr.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
cast,
99
Dict,
1010
List,
11+
Literal,
1112
NamedTuple,
1213
Optional,
1314
overload,
@@ -18,7 +19,6 @@
1819

1920
import torch
2021
from captum._utils.models.linear_model import SkLearnLasso
21-
from captum._utils.typing import Literal
2222
from captum.attr._core.feature_ablation import FeatureAblation
2323
from captum.attr._core.kernel_shap import KernelShap
2424
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -44,9 +44,6 @@ class DummyTokenizer:
4444
@overload
4545
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
4646
@overload
47-
# pyre-fixme[43]: Incompatible overload. The implementation of
48-
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
49-
# pyre-ignore[11]: Annotation `pt` is not defined as a type
5047
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
5148

5249
def encode(
@@ -393,9 +390,6 @@ def test_llm_attr_without_token(
393390
"m n o p q",
394391
skip_tokens=[0],
395392
use_cached_outputs=self.use_cached_outputs,
396-
# pyre-fixme[6]: In call `LLMAttribution.attribute`,
397-
# for 4th positional argument, expected
398-
# `Optional[typing.Callable[..., typing.Any]]` but got `int`.
399393
**attr_kws, # type: ignore
400394
)
401395

@@ -439,10 +433,10 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
439433

440434
# 5 output tokens, 4 input tokens including sos
441435
self.assertEqual(res.seq_attr.shape, (4,))
442-
assert res.token_attr is not None # make pyre/mypy happy
436+
assert res.token_attr is not None
443437
self.assertIsNotNone(res.token_attr)
444438
token_attr = res.token_attr
445-
self.assertEqual(token_attr.shape, (6, 4)) # type: ignore
439+
self.assertEqual(token_attr.shape, (6, 4))
446440
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
447441
self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
448442

@@ -462,18 +456,17 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
462456

463457
# 5 output tokens, 4 input tokens including sos
464458
self.assertEqual(res.seq_attr.shape, (4,))
465-
assert res.token_attr is not None # make pyre/mypy happy
459+
assert res.token_attr is not None
466460
self.assertIsNotNone(res.token_attr)
467461
token_attr = res.token_attr
468-
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
462+
self.assertEqual(token_attr.shape, (5, 4))
469463
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
470464
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
471465

472466

473467
@parameterized_class(
474468
("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
475469
)
476-
# pyre-fixme[13]: Attribute `device` is never initialized.
477470
class TestLLMGradAttr(BaseTest):
478471
# pyre-fixme[13]: Attribute `device` is never initialized.
479472
device: str
@@ -505,16 +498,16 @@ def test_llm_attr(
505498

506499
# 5 output tokens, 4 input tokens including sos
507500
self.assertEqual(res.seq_attr.shape, (4,))
508-
assert res.token_attr is not None # make pyre/mypy happy
501+
assert res.token_attr is not None
509502
self.assertIsNotNone(res.token_attr)
510503
token_attr = res.token_attr
511-
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
504+
self.assertEqual(token_attr.shape, (5, 4))
512505
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
513506
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
514507

515508
self.assertEqual(res.seq_attr.device.type, self.device)
516-
assert res.token_attr is not None # make pyre/mypy happy
517-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
509+
assert res.token_attr is not None
510+
self.assertEqual(token_attr.device.type, self.device)
518511

519512
@parameterized.expand(
520513
[
@@ -542,16 +535,16 @@ def test_llm_attr_without_target(
542535
res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)
543536

544537
self.assertEqual(res.seq_attr.shape, (4,))
545-
assert res.token_attr is not None # make pyre/mypy happy
538+
assert res.token_attr is not None
546539
self.assertIsNotNone(res.token_attr)
547540
token_attr = res.token_attr
548-
self.assertEqual(token_attr.shape, (3, 4)) # type: ignore
541+
self.assertEqual(token_attr.shape, (3, 4))
549542
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
550543
self.assertEqual(res.output_tokens, ["x", "y", "z"])
551544

552545
self.assertEqual(res.seq_attr.device.type, self.device)
553-
assert res.token_attr is not None # make pyre/mypy happy
554-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
546+
assert res.token_attr is not None
547+
self.assertEqual(token_attr.device.type, self.device)
555548

556549
@parameterized.expand(
557550
[
@@ -580,16 +573,16 @@ def test_llm_attr_with_skip_tokens(
580573

581574
# 5 output tokens, 4 input tokens including sos
582575
self.assertEqual(res.seq_attr.shape, (3,))
583-
assert res.token_attr is not None # make pyre/mypy happy
576+
assert res.token_attr is not None
584577
self.assertIsNotNone(res.token_attr)
585578
token_attr = res.token_attr
586-
self.assertEqual(token_attr.shape, (5, 3)) # type: ignore
579+
self.assertEqual(token_attr.shape, (5, 3))
587580
self.assertEqual(res.input_tokens, ["a", "b", "c"])
588581
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
589582

590583
self.assertEqual(res.seq_attr.device.type, self.device)
591-
assert res.token_attr is not None # make pyre/mypy happy
592-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
584+
assert res.token_attr is not None
585+
self.assertEqual(token_attr.device.type, self.device)
593586

594587
def test_llm_attr_with_no_skip_tokens(self) -> None:
595588
llm = DummyLLM()
@@ -602,12 +595,12 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
602595
inp = TextTokenInput("a b c", tokenizer)
603596
res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
604597

605-
# 5 output tokens, 4 input tokens including sos
598+
# 6 output tokens, 4 input tokens including sos
606599
self.assertEqual(res.seq_attr.shape, (4,))
607-
assert res.token_attr is not None # make pyre/mypy happy
600+
assert res.token_attr is not None
608601
self.assertIsNotNone(res.token_attr)
609602
token_attr = res.token_attr
610-
self.assertEqual(token_attr.shape, (6, 4)) # type: ignore
603+
self.assertEqual(token_attr.shape, (6, 4))
611604
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
612605
self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
613606

@@ -629,9 +622,9 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
629622

630623
# 5 output tokens, 4 input tokens including sos
631624
self.assertEqual(res.seq_attr.shape, (4,))
632-
assert res.token_attr is not None # make pyre/mypy happy
625+
assert res.token_attr is not None
633626
self.assertIsNotNone(res.token_attr)
634627
token_attr = res.token_attr
635-
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
628+
self.assertEqual(token_attr.shape, (5, 4))
636629
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
637630
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

tests/attr/test_llm_attr_gpu.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,21 @@
33
# pyre-strict
44

55
import copy
6-
from typing import Any, cast, Dict, List, NamedTuple, Optional, overload, Type, Union
6+
from typing import (
7+
Any,
8+
cast,
9+
Dict,
10+
List,
11+
Literal,
12+
NamedTuple,
13+
Optional,
14+
overload,
15+
Type,
16+
Union,
17+
)
718

819
import torch
920

10-
from captum._utils.typing import Literal
1121
from captum.attr._core.feature_ablation import FeatureAblation
1222
from captum.attr._core.kernel_shap import KernelShap
1323
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
@@ -32,9 +42,6 @@ class DummyTokenizer:
3242
@overload
3343
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
3444
@overload
35-
# pyre-fixme[43]: Incompatible overload. The implementation of
36-
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
37-
# pyre-ignore[11]: Annotation `pt` is not defined as a type
3845
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
3946

4047
def encode(
@@ -122,9 +129,6 @@ def generate(
122129
assert mock_response, "must mock response to use DummyLLM to geenrate"
123130
response = self.tokenizer.encode(mock_response)[1:]
124131
return torch.cat(
125-
# pyre-fixme[6]: In call `torch._C._VariableFunctions.cat`,
126-
# for 1st positional argument, expected `Union[List[Tensor],
127-
# typing.Tuple[Tensor, ...]]` but got `List[Union[List[int], Tensor]]`.
128132
[input_ids, torch.tensor([response], device=self.device)], # type: ignore
129133
dim=1,
130134
)
@@ -178,10 +182,6 @@ def device(self) -> torch._C.device:
178182
else [("cpu", True), ("cpu", False)]
179183
),
180184
)
181-
# pyre-fixme[13]: Attribute `device` is declared in class `TestLlmAttrGpu`
182-
# to have type `str` but is never initialized.
183-
# pyre-fixme[13]: Attribute `use_cached_outputs` is declared in class `TestLlmAttrGpu`
184-
# to have type `bool` but is never initialized.
185185
class TestLlmAttrGpu(BaseTest):
186186
# pyre-fixme[13]: Attribute `device` is never initialized.
187187
device: str
@@ -277,8 +277,6 @@ def test_llm_attr_without_token_gpu(
277277
@parameterized_class(
278278
("device",), [("cuda",)] if torch.cuda.is_available() else [("cpu",)]
279279
)
280-
# pyre-fixme[13]: Attribute `device` is declared in class `TestLLMGradAttrGPU`
281-
# to have type `str` but is never initialized.
282280
class TestLLMGradAttrGPU(BaseTest):
283281
# pyre-fixme[13]: Attribute `device` is never initialized.
284282
device: str
@@ -294,16 +292,16 @@ def test_llm_attr(self) -> None:
294292
res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
295293
# 5 output tokens, 4 input tokens including sos
296294
self.assertEqual(res.seq_attr.shape, (4,))
297-
assert res.token_attr is not None # make pyre/mypy happy
295+
assert res.token_attr is not None
298296
self.assertIsNotNone(res.token_attr)
299297
token_attr = res.token_attr
300-
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
298+
self.assertEqual(token_attr.shape, (5, 4))
301299
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
302300
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
303301

304302
self.assertEqual(res.seq_attr.device.type, self.device)
305-
assert res.token_attr is not None # make pyre/mypy happy
306-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
303+
assert res.token_attr is not None
304+
self.assertEqual(token_attr.device.type, self.device)
307305

308306
def test_llm_attr_without_target(self) -> None:
309307
llm = DummyLLM()
@@ -316,16 +314,16 @@ def test_llm_attr_without_target(self) -> None:
316314
res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"})
317315

318316
self.assertEqual(res.seq_attr.shape, (4,))
319-
assert res.token_attr is not None # make pyre/mypy happy
317+
assert res.token_attr is not None
320318
self.assertIsNotNone(res.token_attr)
321319
token_attr = res.token_attr
322-
self.assertEqual(token_attr.shape, (3, 4)) # type: ignore
320+
self.assertEqual(token_attr.shape, (3, 4))
323321
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
324322
self.assertEqual(res.output_tokens, ["x", "y", "z"])
325323

326-
self.assertEqual(res.seq_attr.device.type, self.device) # type: ignore
327-
assert res.token_attr is not None # make pyre/mypy happy
328-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
324+
self.assertEqual(res.seq_attr.device.type, self.device)
325+
assert res.token_attr is not None
326+
self.assertEqual(token_attr.device.type, self.device)
329327

330328
def test_llm_attr_with_skip_tokens(self) -> None:
331329
llm = DummyLLM()
@@ -337,15 +335,15 @@ def test_llm_attr_with_skip_tokens(self) -> None:
337335
inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
338336
res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
339337

340-
# 5 output tokens, 4 input tokens including sos
338+
# 5 output tokens, 3 input tokens including sos
341339
self.assertEqual(res.seq_attr.shape, (3,))
342-
assert res.token_attr is not None # make pyre/mypy happy
340+
assert res.token_attr is not None
343341
self.assertIsNotNone(res.token_attr)
344342
token_attr = res.token_attr
345-
self.assertEqual(token_attr.shape, (5, 3)) # type: ignore
343+
self.assertEqual(token_attr.shape, (5, 3))
346344
self.assertEqual(res.input_tokens, ["a", "b", "c"])
347345
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
348346

349347
self.assertEqual(res.seq_attr.device.type, self.device)
350-
assert res.token_attr is not None # make pyre/mypy happy
351-
self.assertEqual(token_attr.device.type, self.device) # type: ignore
348+
assert res.token_attr is not None
349+
self.assertEqual(token_attr.device.type, self.device)

0 commit comments

Comments
 (0)