Skip to content

Commit 337ffbe

Browse files
committed
Fix load_data
1 parent ee3e6af commit 337ffbe

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

util.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torch.utils.data import TensorDataset
2+
from torch.utils.data import TensorDataset, random_split
33
import torch.nn.functional as F
44
import numpy as np
55
import beaupy
@@ -16,25 +16,35 @@
1616

1717

1818
def load_data(n=10000, split_ratio=0.8, seed=42):
19-
# Fix Seed
19+
# Fix random seed for reproducibility
2020
torch.manual_seed(seed)
2121

22-
x = torch.linspace(0, 1, n) + torch.rand(n) * 0.01
23-
y = torch.cos(x * (2 * pi)) + torch.rand(n) * 0.01
22+
x_noise = torch.rand(n) * 0.02
23+
x = torch.linspace(0, 1, n) + x_noise
24+
x = x.clamp(0, 1) # Fix x to be in [0, 1]
2425

25-
ics = torch.randperm(n)
26-
ics_train = ics[: int(n * split_ratio)]
27-
ics_val = ics[int(n * split_ratio) :]
26+
noise_level = 0.05
27+
y = (
28+
1.0 * torch.sin(4 * pi * x)
29+
+ 0.5 * torch.sin(10 * pi * x)
30+
+ 1.5 * (x**2)
31+
+ torch.randn(n) * noise_level
32+
)
33+
34+
x = x.view(-1, 1)
35+
y = y.view(-1, 1)
2836

29-
x_train = x[ics_train].view(-1, 1)
30-
y_train = y[ics_train].view(-1, 1)
31-
x_val = x[ics_val].view(-1, 1)
32-
y_val = y[ics_val].view(-1, 1)
37+
full_dataset = TensorDataset(x, y)
3338

34-
train_ds = TensorDataset(x_train, y_train)
35-
val_ds = TensorDataset(x_val, y_val)
39+
train_size = int(n * split_ratio)
40+
val_size = n - train_size
41+
42+
generator = torch.Generator().manual_seed(seed)
43+
train_dataset, val_dataset = random_split(
44+
full_dataset, [train_size, val_size], generator=generator
45+
)
3646

37-
return train_ds, val_ds
47+
return train_dataset, val_dataset
3848

3949

4050
def set_seed(seed: int):

0 commit comments

Comments
 (0)