Skip to content

Commit 2bd58cc

Browse files
committed
Add minibatch size to speed up training when validation set is big
1 parent 05fa295 commit 2bd58cc

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ Think Keras for Language Models applications, a clean, declarative API where:
6868
| 🏢 **Data Scientists** | Integrate LM workflows with APIs & databases. |
6969
| 🎓 **Students/Hobbyists** | Learn AI composition in a clean, intuitive framework. |
7070

71-
No ML background required, just Python skills.
72-
7371
</div>
7472

7573
## Why Synalinks?

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

synalinks/src/trainers/trainer.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ async def fit(
330330
x=None,
331331
y=None,
332332
batch_size=1,
333+
minibatch_size=4,
333334
epochs=1,
334335
verbose="auto",
335336
callbacks=None,
@@ -360,6 +361,10 @@ async def fit(
360361
If unspecified, `batch_size` will default to 32.
361362
Do not specify the `batch_size` if your input data `x` is a
362363
Python generator function since they generate batches.
364+
minibatch_size (int): Integer or `None`.
365+
Number of randomly selected samples per batch validation.
366+
If unspecified, `minibatch_size` will default to 4.
367+
If `None`, the whole validation set will be used.
363368
epochs (int): Integer. Number of epochs to train the program.
364369
An epoch is an iteration over the entire `x` and `y`
365370
data provided (unless the `steps_per_epoch` flag is set to
@@ -532,13 +537,25 @@ async def fit(
532537
)
533538

534539
callbacks.on_train_batch_begin(step)
540+
541+
mini_val_x = None
542+
mini_val_y = None
543+
if minibatch_size:
544+
if len(val_x) > minibatch_size:
545+
indices = np.random.choice(
546+
len(val_x),
547+
size=minibatch_size,
548+
replace=False,
549+
)
550+
mini_val_x = val_x[indices]
551+
mini_val_y = val_y[indices]
535552

536553
logs = await self.train_on_batch(
537554
step=step,
538555
x=x_batch,
539556
y=y_batch,
540-
val_x=val_x,
541-
val_y=val_y,
557+
val_x=mini_val_x if mini_val_x else val_x,
558+
val_y=mini_val_y if mini_val_y else val_y,
542559
return_dict=True,
543560
)
544561

0 commit comments

Comments
 (0)