Skip to content

Commit 6b9dac5

Browse files
committed
fix: reverted val set
1 parent 4de2b60 commit 6b9dac5

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

flaxdiff/data/dataloaders.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -338,28 +338,28 @@ def get_trainset():
338338
)
339339
return loader
340340

341-
def get_valset():
342-
transformations = [
343-
augmenter(),
344-
pygrain.Batch(local_batch_size, drop_remainder=True),
345-
]
341+
# def get_valset():
342+
# transformations = [
343+
# augmenter(),
344+
# pygrain.Batch(local_batch_size, drop_remainder=True),
345+
# ]
346346

347-
loader = pygrain.DataLoader(
348-
data_source=train_source,
349-
sampler=train_sampler,
350-
operations=transformations,
351-
worker_count=2,
352-
read_options=pygrain.ReadOptions(
353-
read_thread_count, read_buffer_size
354-
),
355-
worker_buffer_size=2,
356-
)
357-
return loader
347+
# loader = pygrain.DataLoader(
348+
# data_source=train_source,
349+
# sampler=train_sampler,
350+
# operations=transformations,
351+
# worker_count=2,
352+
# read_options=pygrain.ReadOptions(
353+
# read_thread_count, read_buffer_size
354+
# ),
355+
# worker_buffer_size=2,
356+
# )
357+
# return loader
358358

359359
return {
360360
"train": get_trainset,
361361
"train_len": len(train_source),
362-
"val": get_valset,
362+
"val": get_trainset,
363363
"val_len": len(train_source),
364364
"local_batch_size": local_batch_size,
365365
"global_batch_size": batch_size,

0 commit comments

Comments
 (0)