Skip to content

Commit 4b2cf6d

Browse files
committed
enforce some types on forward
1 parent 98a1bca commit 4b2cf6d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from torch import nn, einsum, Tensor
2+
from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
33
from torch.nn import Module
44
import torch.nn.functional as F
55

66
from beartype import beartype
7+
from beartype.typing import Union
78
from einops import rearrange
89

910
# helpers
@@ -54,9 +55,9 @@ def __init__(
5455
@beartype
5556
def forward(
5657
self,
57-
text_enc: Tensor,
58-
text_enc_with_superclass: Tensor,
59-
concept_indices: Tensor
58+
text_enc: FloatTensor,
59+
text_enc_with_superclass: FloatTensor,
60+
concept_indices: Union[IntTensor, LongTensor]
6061
):
6162
assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
6263

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.3',
6+
version = '0.0.4',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)