Skip to content

Commit 420dbec

Browse files
lingyiyang authored and Benjamin-Walker committed
format as required
1 parent 84a327a commit 420dbec

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

data_dir/dataloaders.py

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,13 @@
44

55
import pandas as pd
66
import torch
7-
from torch.nn.utils.rnn import pad_sequence
8-
from torch.utils.data import DataLoader, Dataset, random_split
9-
107
from abstract_algebra.finite_algebras import (
118
FiniteAlgebra,
129
generate_cyclic_group,
1310
generate_symmetric_group,
1411
)
12+
from torch.nn.utils.rnn import pad_sequence
13+
from torch.utils.data import DataLoader, Dataset, random_split
1514

1615

1716
def generate_group(g: (str, int)) -> FiniteAlgebra:
@@ -190,8 +189,9 @@ def __getitem__(self, idx):
190189

191190

192191
class GroupCompositionDataset(Dataset):
193-
def __init__(self, group='A5', min_length=3, max_length=20, num_samples=1024, seed=1234):
194-
192+
def __init__(
193+
self, group="A5", min_length=3, max_length=20, num_samples=1024, seed=1234
194+
):
195195
super().__init__()
196196
random.seed(seed)
197197
self.seeds = [random.randint(0, 2**32 - 1) for _ in range(num_samples)]
@@ -207,13 +207,10 @@ def __init__(self, group='A5', min_length=3, max_length=20, num_samples=1024, se
207207
self.data_dim = self.group_size
208208
self.label_dim = self.group_size
209209

210-
211210
def __len__(self):
212211
return self.num_samples
213212

214-
215213
def __getitem__(self, idx):
216-
217214
rng = random.Random(self.seeds[idx])
218215
length = rng.randint(self.min_length, self.max_length)
219216

@@ -344,7 +341,6 @@ def create_group_dataloaders(
344341
train_split: float = 0.8,
345342
seed: int = 1234,
346343
):
347-
348344
dataset = GroupCompositionDataset(group, min_length, max_length, num_samples, seed)
349345

350346
def col_fn(batch):
@@ -358,10 +354,28 @@ def col_fn(batch):
358354
[train_size, test_size],
359355
generator=torch.Generator().manual_seed(seed),
360356
)
361-
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=col_fn, num_workers=0)
362-
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=col_fn, num_workers=0)
357+
train_loader = DataLoader(
358+
train_set,
359+
batch_size=batch_size,
360+
shuffle=True,
361+
collate_fn=col_fn,
362+
num_workers=0,
363+
)
364+
test_loader = DataLoader(
365+
test_set,
366+
batch_size=batch_size,
367+
shuffle=False,
368+
collate_fn=col_fn,
369+
num_workers=0,
370+
)
363371
else:
364-
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=col_fn, num_workers=0)
372+
train_loader = DataLoader(
373+
dataset,
374+
batch_size=batch_size,
375+
shuffle=True,
376+
collate_fn=col_fn,
377+
num_workers=0,
378+
)
365379
test_loader = None
366380

367381
data_dim = len(dataset.group.elements)

train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -329,7 +329,7 @@ def train_dataloader_multilength():
329329
yield (X, X_2), (y, y_2), (mask, mask_2)
330330

331331
dataloader = {"train": train_dataloader_multilength(), "val": val_dataloader}
332-
332+
333333
elif task == "A5_generalise":
334334
train_padding_length = 128
335335
if model_name == "lcde":

0 commit comments

Comments
 (0)