Skip to content

Commit 2ac7b49

Browse files
committed
just do tensor for typehint
1 parent 89d39d8 commit 2ac7b49

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from beartype.typing import Union, List, Optional, Tuple
77

88
import torch
9-
from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
9+
from torch import nn, einsum, Tensor
1010
from torch.nn import Module
1111
import torch.nn.functional as F
1212

@@ -16,10 +16,6 @@
1616

1717
from perfusion_pytorch.open_clip import OpenClipAdapter
1818

19-
# constants
20-
21-
IndicesTensor = Union[LongTensor, IntTensor]
22-
2319
# precomputed covariance paths
2420
# will add for more models going forward, if the paper checks out
2521

@@ -73,9 +69,9 @@ def calculate_input_covariance(
7369

7470
@beartype
7571
def loss_fn_weighted_by_mask(
76-
pred: FloatTensor,
77-
target: FloatTensor,
78-
mask: FloatTensor,
72+
pred: Tensor,
73+
target: Tensor,
74+
mask: Tensor,
7975
normalized_mask_min_value = 0.
8076
):
8177
assert mask.shape[-2:] == pred.shape[-2:] == target.shape[-2:]
@@ -212,10 +208,10 @@ def parameters(self):
212208
@beartype
213209
def forward(
214210
self,
215-
text_enc: FloatTensor,
211+
text_enc: Tensor,
216212
*,
217-
concept_indices: Optional[IndicesTensor] = None,
218-
text_enc_with_superclass: Optional[FloatTensor] = None,
213+
concept_indices: Optional[Tensor] = None,
214+
text_enc_with_superclass: Optional[Tensor] = None,
219215
concept_id: Union[int, Tuple[int, ...]] = 0
220216
):
221217
assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.21',
6+
version = '0.1.22',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)