|
1 | 1 | import torch |
2 | | -from torch.utils.data import TensorDataset |
| 2 | +from torch.utils.data import TensorDataset, random_split |
3 | 3 | import torch.nn.functional as F |
4 | 4 | import numpy as np |
5 | 5 | import beaupy |
|
16 | 16 |
|
17 | 17 |
|
18 | 18 | def load_data(n=10000, split_ratio=0.8, seed=42): |
19 | | - # Fix Seed |
| 19 | + # Fix random seed for reproducibility |
20 | 20 | torch.manual_seed(seed) |
21 | 21 |
|
22 | | - x = torch.linspace(0, 1, n) + torch.rand(n) * 0.01 |
23 | | - y = torch.cos(x * (2 * pi)) + torch.rand(n) * 0.01 |
| 22 | + x_noise = torch.rand(n) * 0.02 |
| 23 | + x = torch.linspace(0, 1, n) + x_noise |
| 24 | + x = x.clamp(0, 1) # Fix x to be in [0, 1] |
24 | 25 |
|
25 | | - ics = torch.randperm(n) |
26 | | - ics_train = ics[: int(n * split_ratio)] |
27 | | - ics_val = ics[int(n * split_ratio) :] |
| 26 | + noise_level = 0.05 |
| 27 | + y = ( |
| 28 | + 1.0 * torch.sin(4 * pi * x) |
| 29 | + + 0.5 * torch.sin(10 * pi * x) |
| 30 | + + 1.5 * (x**2) |
| 31 | + + torch.randn(n) * noise_level |
| 32 | + ) |
| 33 | + |
| 34 | + x = x.view(-1, 1) |
| 35 | + y = y.view(-1, 1) |
28 | 36 |
|
29 | | - x_train = x[ics_train].view(-1, 1) |
30 | | - y_train = y[ics_train].view(-1, 1) |
31 | | - x_val = x[ics_val].view(-1, 1) |
32 | | - y_val = y[ics_val].view(-1, 1) |
| 37 | + full_dataset = TensorDataset(x, y) |
33 | 38 |
|
34 | | - train_ds = TensorDataset(x_train, y_train) |
35 | | - val_ds = TensorDataset(x_val, y_val) |
| 39 | + train_size = int(n * split_ratio) |
| 40 | + val_size = n - train_size |
| 41 | + |
| 42 | + generator = torch.Generator().manual_seed(seed) |
| 43 | + train_dataset, val_dataset = random_split( |
| 44 | + full_dataset, [train_size, val_size], generator=generator |
| 45 | + ) |
36 | 46 |
|
37 | | - return train_ds, val_ds |
| 47 | + return train_dataset, val_dataset |
38 | 48 |
|
39 | 49 |
|
40 | 50 | def set_seed(seed: int): |
|
0 commit comments