-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdebug_dataset.py
More file actions
27 lines (20 loc) · 919 Bytes
/
debug_dataset.py
File metadata and controls
27 lines (20 loc) · 919 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from PIL import Image
from torchvision.datasets import CIFAR10
class CIFAR10CLIPDataset(CIFAR10):
    """CIFAR-10 wrapper that yields either (image, label) or (image, caption tokens).

    The base ``CIFAR10`` dataset is built with ``transform=None``; the
    transform supplied here is stored separately and applied inside
    ``__getitem__`` instead.
    """

    def __init__(self, train, root, transform, tokenizer, return_labels=False):
        """Set up the underlying CIFAR-10 dataset and the CLIP-side helpers.

        Args:
            train: Forwarded to ``CIFAR10`` as its ``train`` argument.
            root: Forwarded to ``CIFAR10`` as its ``root`` argument.
            transform: Callable applied to each image in ``__getitem__``;
                a falsy value (e.g. ``None``) disables it.
            tokenizer: Callable invoked on the caption string when captions
                are requested.
            return_labels: When true, ``__getitem__`` returns the integer
                label index instead of tokenized caption text.
        """
        # The base class is always asked to download and never gets the
        # transform; image preprocessing is handled by this subclass.
        super().__init__(root=root, train=train, download=True, transform=None)
        self.return_labels = return_labels
        self.tokenizer = tokenizer
        self.clip_transform = transform
        # Assigned after the base initializer so this list is the one used
        # when building captions.
        self.classes = [
            'airplane',
            'automobile',
            'bird',
            'cat',
            'deer',
            'dog',
            'frog',
            'horse',
            'ship',
            'truck',
        ]

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            ``(image, label_idx)`` when ``return_labels`` is truthy, otherwise
            ``(image, tokens)`` where ``tokens`` is whatever the tokenizer
            returns for the caption ``"a photo of a <class name>"``.
        """
        raw_image, class_id = super().__getitem__(index)
        image = self.clip_transform(raw_image) if self.clip_transform else raw_image

        if not self.return_labels:
            class_name = self.classes[class_id]
            prompt = f"a photo of a {class_name}"
            return image, self.tokenizer(prompt)

        return image, class_id