1
1
# pyre-strict
2
2
from abc import ABC , abstractmethod
3
- from typing import Any , Callable , cast , Dict , List , Optional , Tuple , Union
3
+ from typing import Callable , cast , Dict , List , Optional , Tuple , Union
4
4
5
5
import torch
6
6
@@ -116,17 +116,19 @@ def to_tensor(self) -> Tensor:
116
116
pass
117
117
118
118
@abstractmethod
119
- # pyre-fixme[3]: Return annotation cannot be `Any`.
120
- def to_model_input (self , itp_tensor : Optional [Tensor ] = None ) -> Any :
119
+ def to_model_input (
120
+ self , perturbed_tensor : Optional [Tensor ] = None
121
+ ) -> Union [str , Tensor ]:
121
122
"""
122
123
Get the (perturbed) input in the format required by the model
123
124
based on the given (perturbed) interpretable representation.
124
125
125
126
Args:
126
127
127
- itp_tensor (Tensor, optional): tensor of the interpretable representation
128
- of this input. If it is None, assume the interpretable
129
- representation is pristine and return the original model input
128
+ perturbed_tensor (Tensor, optional): tensor of the interpretable
129
+ representation of this input. If it is None, assume the
130
+ interpretable representation is pristine and return the
131
+ original model input
130
132
Default: None.
131
133
132
134
@@ -198,13 +200,25 @@ class TextTemplateInput(InterpretableInput):
198
200
199
201
"""
200
202
203
+ values : List [str ]
204
+ dict_keys : List [str ]
205
+ baselines : Union [List [str ], Callable [[], Union [List [str ], Dict [str , str ]]]]
206
+ n_features : int
207
+ n_itp_features : int
208
+ format_fn : Callable [..., str ]
209
+ mask : Union [List [int ], Dict [str , int ], None ]
210
+ formatted_mask : List [int ]
211
+
201
212
def __init__ (
202
213
self ,
203
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
204
- template : Union [str , Callable ],
214
+ template : Union [str , Callable [..., str ]],
205
215
values : Union [List [str ], Dict [str , str ]],
206
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
207
- baselines : Union [List [str ], Dict [str , str ], Callable , None ] = None ,
216
+ baselines : Union [
217
+ List [str ],
218
+ Dict [str , str ],
219
+ Callable [[], Union [List [str ], Dict [str , str ]]],
220
+ None ,
221
+ ] = None ,
208
222
mask : Union [List [int ], Dict [str , int ], None ] = None ,
209
223
) -> None :
210
224
# convert values dict to list
@@ -217,8 +231,8 @@ def __init__(
217
231
), f"the values must be either a list or a dict, received: { type (values )} "
218
232
dict_keys = []
219
233
220
- self .values : List [ str ] = values
221
- self .dict_keys : List [ str ] = dict_keys
234
+ self .values = values
235
+ self .dict_keys = dict_keys
222
236
223
237
n_features = len (values )
224
238
@@ -261,15 +275,12 @@ def __init__(
261
275
262
276
# internal compressed mask of continuous interpretable indices from 0
263
277
# cannot replace original mask of ids for grouping across values externally
264
- # pyre-fixme[4]: Attribute must be annotated.
265
278
self .formatted_mask = [mask_id_to_idx [mid ] for mid in mask ]
266
279
267
280
n_itp_features = len (mask_ids )
268
281
269
282
# number of raw features and intepretable features
270
- # pyre-fixme[4]: Attribute must be annotated.
271
283
self .n_features = n_features
272
- # pyre-fixme[4]: Attribute must be annotated.
273
284
self .n_itp_features = n_itp_features
274
285
275
286
if isinstance (template , str ):
@@ -280,7 +291,6 @@ def __init__(
280
291
f"received: { type (template )} "
281
292
)
282
293
template = template
283
- # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
284
294
self .format_fn = template
285
295
286
296
self .mask = mask
@@ -289,8 +299,6 @@ def to_tensor(self) -> torch.Tensor:
289
299
# Interpretable representation in shape(1, n_itp_features)
290
300
return torch .tensor ([[1.0 ] * self .n_itp_features ])
291
301
292
- # pyre-fixme[14]: `to_model_input` overrides method defined in
293
- # `InterpretableInput` inconsistently.
294
302
def to_model_input (self , perturbed_tensor : Optional [Tensor ] = None ) -> str :
295
303
values = list (self .values ) # clone
296
304
@@ -321,18 +329,12 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
321
329
itp_val = perturbed_tensor [0 ][itp_idx ]
322
330
323
331
if not itp_val :
324
- # pyre-fixme[16]: Item `None` of `Union[None, Dict[str, str],
325
- # List[typing.Any]]` has no attribute `__getitem__`.
326
332
values [i ] = baselines [i ]
327
333
328
334
if self .dict_keys :
329
335
dict_values = dict (zip (self .dict_keys , values ))
330
- # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
331
- # a function.
332
336
input_str = self .format_fn (** dict_values )
333
337
else :
334
- # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
335
- # a function.
336
338
input_str = self .format_fn (* values )
337
339
338
340
return input_str
@@ -391,6 +393,14 @@ class TextTokenInput(InterpretableInput):
391
393
392
394
"""
393
395
396
+ inp_tensor : Tensor
397
+ itp_tensor : Tensor
398
+ itp_mask : Optional [Tensor ]
399
+ values : List [str ]
400
+ tokenizer : TokenizerLike
401
+ n_itp_features : int
402
+ baselines : int
403
+
394
404
def __init__ (
395
405
self ,
396
406
text : str ,
@@ -401,11 +411,11 @@ def __init__(
401
411
inp_tensor = tokenizer .encode (text , return_tensors = "pt" )
402
412
403
413
# input tensor into the model of token ids
404
- self .inp_tensor : Tensor = inp_tensor
414
+ self .inp_tensor = inp_tensor
405
415
# tensor of interpretable token ids
406
- self .itp_tensor : Tensor = inp_tensor
416
+ self .itp_tensor = inp_tensor
407
417
# interpretable mask
408
- self .itp_mask : Optional [ Tensor ] = None
418
+ self .itp_mask = None
409
419
410
420
if skip_tokens :
411
421
if isinstance (skip_tokens [0 ], str ):
@@ -426,13 +436,11 @@ def __init__(
426
436
self .skip_tokens = skip_tokens
427
437
428
438
# features values, the tokens
429
- self .values : List [str ] = tokenizer .convert_ids_to_tokens (
430
- self .itp_tensor [0 ].tolist ()
431
- )
432
- self .tokenizer : TokenizerLike = tokenizer
433
- self .n_itp_features : int = len (self .values )
439
+ self .values = tokenizer .convert_ids_to_tokens (self .itp_tensor [0 ].tolist ())
440
+ self .tokenizer = tokenizer
441
+ self .n_itp_features = len (self .values )
434
442
435
- self .baselines : int = (
443
+ self .baselines = (
436
444
baselines
437
445
if type (baselines ) is int
438
446
else tokenizer .convert_tokens_to_ids ([baselines ])[0 ] # type: ignore
@@ -442,8 +450,6 @@ def to_tensor(self) -> torch.Tensor:
442
450
# return the perturbation indicator as interpretable tensor instead of token ids
443
451
return torch .ones_like (self .itp_tensor )
444
452
445
- # pyre-fixme[14]: `to_model_input` overrides method defined in
446
- # `InterpretableInput` inconsistently.
447
453
def to_model_input (self , perturbed_tensor : Optional [Tensor ] = None ) -> Tensor :
448
454
if perturbed_tensor is None :
449
455
return self .inp_tensor
0 commit comments