Skip to content

Commit 0ad09c4

Browse files
committed
allow channels to be customizable for cvt
1 parent 92b6932 commit 0ad09c4

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '1.6.2',
6+
version = '1.6.3',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
long_description_content_type = 'text/markdown',

vit_pytorch/cvt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,13 @@ def __init__(
140140
s3_heads = 6,
141141
s3_depth = 10,
142142
s3_mlp_mult = 4,
143-
dropout = 0.
143+
dropout = 0.,
144+
channels = 3
144145
):
145146
super().__init__()
146147
kwargs = dict(locals())
147148

148-
dim = 3
149+
dim = channels
149150
layers = []
150151

151152
for prefix in ('s1', 's2', 's3'):

0 commit comments

Comments
 (0)