-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path dataset.py
More file actions
29 lines (23 loc) · 984 Bytes
/
dataset.py
File metadata and controls
29 lines (23 loc) · 984 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
class PNGDataset(Dataset):
    """Dataset of ``*.png`` files from one directory, each paired with a random timestep.

    Each item is ``(image, timestep)``:

    * ``image`` -- float32 tensor of shape ``(1, H, W)``, produced by converting
      the PNG to Pillow mode ``"I;16"`` and dividing by 65535.
    * ``timestep`` -- Python ``int`` drawn with ``torch.randint`` from
      ``[0, num_timesteps)``; a fresh value is drawn on every access.

    NOTE(review): Pillow's conversion to ``"I;16"`` does not appear to rescale
    8-bit ("L") sources, so such PNGs would land in ``[0, 255/65535]`` rather
    than ``[0, 1]`` -- confirm the input directory only holds 16-bit images.
    """

    def __init__(self, image_dir, num_timesteps=1000):
        """Collect the PNG files located directly inside ``image_dir``.

        Args:
            image_dir: Directory scanned (non-recursively) for ``*.png`` files.
            num_timesteps: Exclusive upper bound for the random timestep
                returned with each image. The default of 1000 yields
                timesteps in ``[0, 999]``, matching the previous behavior.
        """
        # glob.glob returns entries in arbitrary, filesystem-dependent order;
        # sorting makes the index -> file mapping reproducible across runs.
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.num_timesteps = num_timesteps

    def __len__(self):
        """Return the number of PNG files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as a ``(1, H, W)`` float tensor plus a random timestep.

        Args:
            idx: Index into ``self.image_paths``.

        Returns:
            ``(image, timestep)`` as described in the class docstring.
        """
        img_path = self.image_paths[idx]
        # Use a context manager so the underlying file handle is released now,
        # not whenever the PIL object is garbage-collected (matters when many
        # DataLoader workers open files concurrently).
        with Image.open(img_path) as src:
            # convert() materializes pixel data into a new image, so the result
            # remains valid after the source file is closed.
            gray16 = src.convert("I;16")
        # 16-bit grayscale -> float in [0, 1] (for genuinely 16-bit sources).
        pixels = np.array(gray16) / 65535.0
        # Add a leading channel dimension: (H, W) -> (1, H, W).
        image = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)
        # Random timestep in [0, num_timesteps), drawn from torch's global RNG.
        timestep = torch.randint(0, self.num_timesteps, (1,)).item()
        return image, timestep
if "__main__" == __name__:
    # Running this file directly does nothing; the module only defines
    # PNGDataset for use by other code.
    ...