Skip to content

Commit 5c0e3eb

Browse files
committed
feat: use default precision
1 parent 3c7fff0 commit 5c0e3eb

File tree

3 files changed

+75
-29
lines changed

3 files changed

+75
-29
lines changed

flaxdiff/data/dataloaders.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -336,27 +336,29 @@ def get_trainset():
336336
)
337337
return loader
338338

339-
# def get_valset():
340-
# transformations = [
341-
# augmenter(),
342-
# pygrain.Batch(local_batch_size, drop_remainder=True),
343-
# ]
339+
def get_valset():
340+
transformations = [
341+
augmenter(),
342+
pygrain.Batch(local_batch_size, drop_remainder=True),
343+
]
344344

345-
# loader = pygrain.DataLoader(
346-
# data_source=data_source,
347-
# sampler=val_sampler,
348-
# operations=transformations,
349-
# worker_count=worker_count,
350-
# read_options=pygrain.ReadOptions(
351-
# read_thread_count, read_buffer_size
352-
# ),
353-
# worker_buffer_size=worker_buffer_size,
354-
# )
355-
# return loader
345+
loader = pygrain.DataLoader(
346+
data_source=data_source,
347+
sampler=train_sampler,
348+
operations=transformations,
349+
worker_count=4,
350+
read_options=pygrain.ReadOptions(
351+
5, 100
352+
),
353+
worker_buffer_size=5,
354+
)
355+
return loader
356356

357357
return {
358358
"train": get_trainset,
359359
"train_len": len(data_source),
360+
"val": get_valset,
361+
"val_len": len(data_source),
360362
"local_batch_size": local_batch_size,
361363
"global_batch_size": batch_size,
362364
}

0 commit comments

Comments
 (0)