@@ -8,6 +8,7 @@
     cast,
     Dict,
     List,
+    Literal,
     NamedTuple,
     Optional,
     overload,
@@ -18,7 +19,6 @@
 
 import torch
 from captum._utils.models.linear_model import SkLearnLasso
-from captum._utils.typing import Literal
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -44,9 +44,6 @@ class DummyTokenizer:
     @overload
     def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     @overload
-    # pyre-fixme[43]: Incompatible overload. The implementation of
-    # `DummyTokenizer.encode` does not accept all possible arguments of overload.
-    # pyre-ignore[11]: Annotation `pt` is not defined as a type
     def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
 
     def encode(
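Note on the hunk above: these overloads now type-check because `Literal` comes from the stdlib `typing` module (see the import hunks at the top of the file) instead of `captum._utils.typing`, so the pyre suppressions are no longer needed. Below is a minimal, self-contained sketch of the same pattern, not taken from the diff; the `ToyTokenizer` name and its ord-based encoding are hypothetical stand-ins for `DummyTokenizer`. A checker selects the return type from the literal value passed for `return_tensors`.

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor


class ToyTokenizer:
    # Hypothetical stand-in, only to illustrate the Literal overload pattern.
    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [ord(c) for c in text]  # trivial "tokenization" for the sketch
        if return_tensors == "pt":
            return torch.tensor([ids])
        return ids


ids = ToyTokenizer().encode("abc")       # checker infers List[int]
t = ToyTokenizer().encode("abc", "pt")   # checker infers Tensor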
@@ -393,9 +390,6 @@ def test_llm_attr_without_token(
             "m n o p q",
             skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
-            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
-            # for 4th positional argument, expected
-            # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
             **attr_kws,  # type: ignore
         )
 
@@ -439,10 +433,10 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (6, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
 
@@ -462,18 +456,17 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (5, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
 
 
 @parameterized_class(
     ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
 )
-# pyre-fixme[13]: Attribute `device` is never initialized.
 class TestLLMGradAttr(BaseTest):
     # pyre-fixme[13]: Attribute `device` is never initialized.
     device: str
@@ -505,16 +498,16 @@ def test_llm_attr(
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (5, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
 
         self.assertEqual(res.seq_attr.device.type, self.device)
-        assert res.token_attr is not None  # make pyre/mypy happy
-        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+        assert res.token_attr is not None
+        self.assertEqual(token_attr.device.type, self.device)
 
     @parameterized.expand(
         [
@@ -542,16 +535,16 @@ def test_llm_attr_without_target(
         res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)
 
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (3, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (3, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["x", "y", "z"])
 
         self.assertEqual(res.seq_attr.device.type, self.device)
-        assert res.token_attr is not None  # make pyre/mypy happy
-        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+        assert res.token_attr is not None
+        self.assertEqual(token_attr.device.type, self.device)
 
     @parameterized.expand(
         [
@@ -580,16 +573,16 @@ def test_llm_attr_with_skip_tokens(
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (5, 3))  # type: ignore
+        self.assertEqual(token_attr.shape, (5, 3))
         self.assertEqual(res.input_tokens, ["a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
 
         self.assertEqual(res.seq_attr.device.type, self.device)
-        assert res.token_attr is not None  # make pyre/mypy happy
-        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+        assert res.token_attr is not None
+        self.assertEqual(token_attr.device.type, self.device)
 
     def test_llm_attr_with_no_skip_tokens(self) -> None:
         llm = DummyLLM()
@@ -602,12 +595,12 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
         inp = TextTokenInput("a b c", tokenizer)
         res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
-        # 5 output tokens, 4 input tokens including sos
+        # 6 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (6, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
 
@@ -629,9 +622,9 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
-        assert res.token_attr is not None  # make pyre/mypy happy
+        assert res.token_attr is not None
         self.assertIsNotNone(res.token_attr)
         token_attr = res.token_attr
-        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(token_attr.shape, (5, 4))
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
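The recurring replacement in this commit, a bare `assert res.token_attr is not None` instead of a `# make pyre/mypy happy` comment plus `# type: ignore`, relies on standard type narrowing: after the assert, both mypy and pyre treat the value as non-Optional, so later attribute accesses need no suppression. A minimal sketch of that pattern under the same assumption (`check_token_attr` is a hypothetical helper, not part of the test file):

from typing import Optional

import torch
from torch import Tensor


def check_token_attr(token_attr: Optional[Tensor]) -> None:
    # Before the assert, token_attr is Optional[Tensor] and .shape would be
    # flagged; the assert narrows it to Tensor for the rest of the function.
    assert token_attr is not None
    print(tuple(token_attr.shape))


check_token_attr(torch.zeros(5, 4))  # prints (5, 4)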