|
6 | 6 | from beartype.typing import Union, List, Optional, Tuple |
7 | 7 |
|
8 | 8 | import torch |
9 | | -from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor |
| 9 | +from torch import nn, einsum, Tensor |
10 | 10 | from torch.nn import Module |
11 | 11 | import torch.nn.functional as F |
12 | 12 |
|
|
16 | 16 |
|
17 | 17 | from perfusion_pytorch.open_clip import OpenClipAdapter |
18 | 18 |
|
19 | | -# constants |
20 | | - |
21 | | -IndicesTensor = Union[LongTensor, IntTensor] |
22 | | - |
23 | 19 | # precomputed covariance paths |
24 | 20 | # will add for more models going forward, if the paper checks out |
25 | 21 |
|
@@ -73,9 +69,9 @@ def calculate_input_covariance( |
73 | 69 |
|
74 | 70 | @beartype |
75 | 71 | def loss_fn_weighted_by_mask( |
76 | | - pred: FloatTensor, |
77 | | - target: FloatTensor, |
78 | | - mask: FloatTensor, |
| 72 | + pred: Tensor, |
| 73 | + target: Tensor, |
| 74 | + mask: Tensor, |
79 | 75 | normalized_mask_min_value = 0. |
80 | 76 | ): |
81 | 77 | assert mask.shape[-2:] == pred.shape[-2:] == target.shape[-2:] |
@@ -212,10 +208,10 @@ def parameters(self): |
212 | 208 | @beartype |
213 | 209 | def forward( |
214 | 210 | self, |
215 | | - text_enc: FloatTensor, |
| 211 | + text_enc: Tensor, |
216 | 212 | *, |
217 | | - concept_indices: Optional[IndicesTensor] = None, |
218 | | - text_enc_with_superclass: Optional[FloatTensor] = None, |
| 213 | + concept_indices: Optional[Tensor] = None, |
| 214 | + text_enc_with_superclass: Optional[Tensor] = None, |
219 | 215 | concept_id: Union[int, Tuple[int, ...]] = 0 |
220 | 216 | ): |
221 | 217 | assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}' |
|
0 commit comments