44
55import pandas as pd
66import torch
7- from torch .nn .utils .rnn import pad_sequence
8- from torch .utils .data import DataLoader , Dataset , random_split
9-
107from abstract_algebra .finite_algebras import (
118 FiniteAlgebra ,
129 generate_cyclic_group ,
1310 generate_symmetric_group ,
1411)
12+ from torch .nn .utils .rnn import pad_sequence
13+ from torch .utils .data import DataLoader , Dataset , random_split
1514
1615
1716def generate_group (g : (str , int )) -> FiniteAlgebra :
@@ -190,8 +189,9 @@ def __getitem__(self, idx):
190189
191190
192191class GroupCompositionDataset (Dataset ):
193- def __init__ (self , group = 'A5' , min_length = 3 , max_length = 20 , num_samples = 1024 , seed = 1234 ):
194-
192+ def __init__ (
193+ self , group = "A5" , min_length = 3 , max_length = 20 , num_samples = 1024 , seed = 1234
194+ ):
195195 super ().__init__ ()
196196 random .seed (seed )
197197 self .seeds = [random .randint (0 , 2 ** 32 - 1 ) for _ in range (num_samples )]
@@ -207,13 +207,10 @@ def __init__(self, group='A5', min_length=3, max_length=20, num_samples=1024, se
207207 self .data_dim = self .group_size
208208 self .label_dim = self .group_size
209209
210-
211210 def __len__ (self ):
212211 return self .num_samples
213212
214-
215213 def __getitem__ (self , idx ):
216-
217214 rng = random .Random (self .seeds [idx ])
218215 length = rng .randint (self .min_length , self .max_length )
219216
@@ -344,7 +341,6 @@ def create_group_dataloaders(
344341 train_split : float = 0.8 ,
345342 seed : int = 1234 ,
346343):
347-
348344 dataset = GroupCompositionDataset (group , min_length , max_length , num_samples , seed )
349345
350346 def col_fn (batch ):
@@ -358,10 +354,28 @@ def col_fn(batch):
358354 [train_size , test_size ],
359355 generator = torch .Generator ().manual_seed (seed ),
360356 )
361- train_loader = DataLoader (train_set , batch_size = batch_size , shuffle = True , collate_fn = col_fn , num_workers = 0 )
362- test_loader = DataLoader (test_set , batch_size = batch_size , shuffle = False , collate_fn = col_fn , num_workers = 0 )
357+ train_loader = DataLoader (
358+ train_set ,
359+ batch_size = batch_size ,
360+ shuffle = True ,
361+ collate_fn = col_fn ,
362+ num_workers = 0 ,
363+ )
364+ test_loader = DataLoader (
365+ test_set ,
366+ batch_size = batch_size ,
367+ shuffle = False ,
368+ collate_fn = col_fn ,
369+ num_workers = 0 ,
370+ )
363371 else :
364- train_loader = DataLoader (dataset , batch_size = batch_size , shuffle = True , collate_fn = col_fn , num_workers = 0 )
372+ train_loader = DataLoader (
373+ dataset ,
374+ batch_size = batch_size ,
375+ shuffle = True ,
376+ collate_fn = col_fn ,
377+ num_workers = 0 ,
378+ )
365379 test_loader = None
366380
367381 data_dim = len (dataset .group .elements )
0 commit comments