forked from zhuyu-cs/MeanFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
48 lines (37 loc) · 1.41 KB
/
dataset.py
File metadata and controls
48 lines (37 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import lmdb
import pickle
import numpy as np
import os
class LMDBLatentsDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads pickled sample records from an LMDB file.

    The database is expected to contain:

    * a ``num_samples`` entry holding the sample count as a decimal string;
    * one entry per sample, keyed by the sample index rendered as a decimal
      string, whose value is a pickled dict with the keys ``moments``,
      ``moments_flip`` and ``label``.

    Args:
        lmdb_path (str): LMDB dataset path.
        flip_prob (float): Probability of returning ``moments_flip`` instead
            of ``moments`` for a sample. ``0`` never flips, ``1`` always flips.

    Raises:
        KeyError: If the database has no ``num_samples`` entry.
    """

    def __init__(self, lmdb_path, flip_prob=0.5):
        # NOTE(review): the environment is opened eagerly and kept for the
        # lifetime of the object. If this dataset is used with DataLoader
        # worker processes, confirm that sharing the handle matches LMDB's
        # guidance on fork/multi-process use.
        self.env = lmdb.open(
            lmdb_path,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.flip_prob = flip_prob

        with self.env.begin() as txn:
            raw_count = txn.get('num_samples'.encode())
        if raw_count is None:
            # Fail with a descriptive error instead of an AttributeError on
            # ``None.decode()``.
            raise KeyError(
                f"'num_samples' entry not found in LMDB at {lmdb_path!r}"
            )
        self.length = int(raw_count.decode())

    def __len__(self):
        """Return the sample count read from the ``num_samples`` entry."""
        return self.length

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Sample index; its ``str()`` form is used as the LMDB key.

        Returns:
            tuple: ``(moments_tensor, label)`` where ``moments_tensor`` is a
            float32 tensor built from either ``moments`` or ``moments_flip``
            (chosen at random according to ``flip_prob``) and ``label`` is the
            stored label, returned as-is.

        Raises:
            IndexError: If no entry exists for ``index``.
        """
        # Keep the read transaction as short as possible: only the lookup
        # needs it, decoding happens afterwards.
        with self.env.begin() as txn:
            data = txn.get(f'{index}'.encode())
        if data is None:
            raise IndexError(f'Index {index} is out of bounds')

        # SECURITY: pickle.loads can execute arbitrary code. Only open LMDB
        # files that come from a trusted source.
        data = pickle.loads(data)

        moments = data['moments']
        moments_flip = data['moments_flip']
        label = data['label']

        # torch.rand draws from [0, 1), so flip_prob=0 never flips and
        # flip_prob=1 always flips.
        use_flip = torch.rand(1).item() < self.flip_prob
        moments_to_use = moments_flip if use_flip else moments

        moments_tensor = torch.from_numpy(moments_to_use).float()
        return moments_tensor, label

    def __del__(self):
        # ``env`` is missing when ``lmdb.open`` failed inside ``__init__`` (or
        # the instance was created without running it); the finalizer must
        # not raise in that case.
        env = getattr(self, 'env', None)
        if env is not None:
            env.close()