We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5a3b333 commit de85c09 — Copy full SHA for de85c09
vision_transformer/main.py
@@ -49,9 +49,9 @@ def __init__(self, args):
49
# Linear projection
50
self.LinearProjection = nn.Linear(self.input_size, self.latent_size)
51
# Class token
52
- self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+ self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
53
# Positional embedding
54
- self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+ self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
55
56
def forward(self, input_data):
57
input_data = input_data.to(self.device)
0 commit comments