Skip to content

Commit 46e50db

Browse files
authored
Add split functionality to synthetic_data
1 parent e982248 commit 46e50db

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

cebra/datasets/synthetic_data.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,25 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True):
112112
self.index = self.data['u']
113113
self.lam = self.data['lam']
114114

115+
116+
def split(self, split):
117+
tot_len = len(self.neural)
118+
train_idx = np.arange(tot_len)[:int(tot_len*0.8)]
119+
valid_idx = np.arange(tot_len)[int(tot_len*0.8):]
120+
121+
if split == 'train':
122+
self.neural = self.neural[train_idx]
123+
self.index = self.index[train_idx]
124+
self.idx = train_idx
125+
elif split == 'valid':
126+
self.neural = self.neural[valid_idx]
127+
self.index = self.index[valid_idx]
128+
self.idx = valid_idx
129+
elif split == 'all':
130+
pass
131+
else:
132+
raise ValueError(f"{split} not supported")
133+
115134
@property
116135
def input_dimension(self):
117136
return self.neural.size(1)

0 commit comments

Comments
 (0)