Skip to content

Commit 0af8bd6

Browse files
csauper authored and facebook-github-bot committed
Add types to fix pyre errors in InterpretableInput [1/n] (#1356)
Summary: Pull Request resolved: #1356. Fix pyre errors and improve the code by adding missing types in `InterpretableInput`. Reviewed By: vivekmig. Differential Revision: D63304771. fbshipit-source-id: 70214f0318ef32d8bdcda0f3034869e51d03440c
1 parent 9600e28 commit 0af8bd6

File tree

2 files changed

+46
-38
lines changed

2 files changed

+46
-38
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,6 @@ def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
412412
"""
413413
# return tensor(1, n_tokens)
414414
if isinstance(model_input, str):
415-
# pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
416-
# Tensor
417415
return self.tokenizer.encode( # type: ignore
418416
model_input, return_tensors="pt"
419417
).to(self.device)
@@ -609,10 +607,14 @@ class created with the llm model that follows huggingface style
609607
else next(self.model.parameters()).device
610608
)
611609

612-
def _format_model_input(self, model_input: Tensor) -> Tensor:
610+
def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
613611
"""
614612
Convert str to tokenized tensor
615613
"""
614+
if isinstance(model_input, str):
615+
return self.tokenizer.encode( # type: ignore
616+
model_input, return_tensors="pt"
617+
).to(self.device)
616618
return model_input.to(self.device)
617619

618620
def attribute(

captum/attr/_utils/interpretable_input.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pyre-strict
22
from abc import ABC, abstractmethod
3-
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
3+
from typing import Callable, cast, Dict, List, Optional, Tuple, Union
44

55
import torch
66

@@ -116,17 +116,19 @@ def to_tensor(self) -> Tensor:
116116
pass
117117

118118
@abstractmethod
119-
# pyre-fixme[3]: Return annotation cannot be `Any`.
120-
def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any:
119+
def to_model_input(
120+
self, perturbed_tensor: Optional[Tensor] = None
121+
) -> Union[str, Tensor]:
121122
"""
122123
Get the (perturbed) input in the format required by the model
123124
based on the given (perturbed) interpretable representation.
124125
125126
Args:
126127
127-
itp_tensor (Tensor, optional): tensor of the interpretable representation
128-
of this input. If it is None, assume the interpretable
129-
representation is pristine and return the original model input
128+
perturbed_tensor (Tensor, optional): tensor of the interpretable
129+
representation of this input. If it is None, assume the
130+
interpretable representation is pristine and return the
131+
original model input
130132
Default: None.
131133
132134
@@ -198,13 +200,25 @@ class TextTemplateInput(InterpretableInput):
198200
199201
"""
200202

203+
values: List[str]
204+
dict_keys: List[str]
205+
baselines: Union[List[str], Callable[[], Union[List[str], Dict[str, str]]]]
206+
n_features: int
207+
n_itp_features: int
208+
format_fn: Callable[..., str]
209+
mask: Union[List[int], Dict[str, int], None]
210+
formatted_mask: List[int]
211+
201212
def __init__(
202213
self,
203-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
204-
template: Union[str, Callable],
214+
template: Union[str, Callable[..., str]],
205215
values: Union[List[str], Dict[str, str]],
206-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
207-
baselines: Union[List[str], Dict[str, str], Callable, None] = None,
216+
baselines: Union[
217+
List[str],
218+
Dict[str, str],
219+
Callable[[], Union[List[str], Dict[str, str]]],
220+
None,
221+
] = None,
208222
mask: Union[List[int], Dict[str, int], None] = None,
209223
) -> None:
210224
# convert values dict to list
@@ -217,8 +231,8 @@ def __init__(
217231
), f"the values must be either a list or a dict, received: {type(values)}"
218232
dict_keys = []
219233

220-
self.values: List[str] = values
221-
self.dict_keys: List[str] = dict_keys
234+
self.values = values
235+
self.dict_keys = dict_keys
222236

223237
n_features = len(values)
224238

@@ -261,15 +275,12 @@ def __init__(
261275

262276
# internal compressed mask of continuous interpretable indices from 0
263277
# cannot replace original mask of ids for grouping across values externally
264-
# pyre-fixme[4]: Attribute must be annotated.
265278
self.formatted_mask = [mask_id_to_idx[mid] for mid in mask]
266279

267280
n_itp_features = len(mask_ids)
268281

269282
# number of raw features and interpretable features
270-
# pyre-fixme[4]: Attribute must be annotated.
271283
self.n_features = n_features
272-
# pyre-fixme[4]: Attribute must be annotated.
273284
self.n_itp_features = n_itp_features
274285

275286
if isinstance(template, str):
@@ -280,7 +291,6 @@ def __init__(
280291
f"received: {type(template)}"
281292
)
282293
template = template
283-
# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
284294
self.format_fn = template
285295

286296
self.mask = mask
@@ -289,8 +299,6 @@ def to_tensor(self) -> torch.Tensor:
289299
# Interpretable representation in shape(1, n_itp_features)
290300
return torch.tensor([[1.0] * self.n_itp_features])
291301

292-
# pyre-fixme[14]: `to_model_input` overrides method defined in
293-
# `InterpretableInput` inconsistently.
294302
def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
295303
values = list(self.values) # clone
296304

@@ -321,18 +329,12 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
321329
itp_val = perturbed_tensor[0][itp_idx]
322330

323331
if not itp_val:
324-
# pyre-fixme[16]: Item `None` of `Union[None, Dict[str, str],
325-
# List[typing.Any]]` has no attribute `__getitem__`.
326332
values[i] = baselines[i]
327333

328334
if self.dict_keys:
329335
dict_values = dict(zip(self.dict_keys, values))
330-
# pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
331-
# a function.
332336
input_str = self.format_fn(**dict_values)
333337
else:
334-
# pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
335-
# a function.
336338
input_str = self.format_fn(*values)
337339

338340
return input_str
@@ -391,6 +393,14 @@ class TextTokenInput(InterpretableInput):
391393
392394
"""
393395

396+
inp_tensor: Tensor
397+
itp_tensor: Tensor
398+
itp_mask: Optional[Tensor]
399+
values: List[str]
400+
tokenizer: TokenizerLike
401+
n_itp_features: int
402+
baselines: int
403+
394404
def __init__(
395405
self,
396406
text: str,
@@ -401,11 +411,11 @@ def __init__(
401411
inp_tensor = tokenizer.encode(text, return_tensors="pt")
402412

403413
# input tensor into the model of token ids
404-
self.inp_tensor: Tensor = inp_tensor
414+
self.inp_tensor = inp_tensor
405415
# tensor of interpretable token ids
406-
self.itp_tensor: Tensor = inp_tensor
416+
self.itp_tensor = inp_tensor
407417
# interpretable mask
408-
self.itp_mask: Optional[Tensor] = None
418+
self.itp_mask = None
409419

410420
if skip_tokens:
411421
if isinstance(skip_tokens[0], str):
@@ -426,13 +436,11 @@ def __init__(
426436
self.skip_tokens = skip_tokens
427437

428438
# features values, the tokens
429-
self.values: List[str] = tokenizer.convert_ids_to_tokens(
430-
self.itp_tensor[0].tolist()
431-
)
432-
self.tokenizer: TokenizerLike = tokenizer
433-
self.n_itp_features: int = len(self.values)
439+
self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0].tolist())
440+
self.tokenizer = tokenizer
441+
self.n_itp_features = len(self.values)
434442

435-
self.baselines: int = (
443+
self.baselines = (
436444
baselines
437445
if type(baselines) is int
438446
else tokenizer.convert_tokens_to_ids([baselines])[0] # type: ignore
@@ -442,8 +450,6 @@ def to_tensor(self) -> torch.Tensor:
442450
# return the perturbation indicator as interpretable tensor instead of token ids
443451
return torch.ones_like(self.itp_tensor)
444452

445-
# pyre-fixme[14]: `to_model_input` overrides method defined in
446-
# `InterpretableInput` inconsistently.
447453
def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Tensor:
448454
if perturbed_tensor is None:
449455
return self.inp_tensor

0 commit comments

Comments
 (0)