Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 8367b39

Browse files
authored
Fix typo (#935)
1 parent cbd117c commit 8367b39

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/mnist/train_ray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create_train_state(rng, config):
118118
apply_fn=cnn.apply, params=params, tx=tx)
119119

120120

121-
def get_train_data_laoder(train_ds, state, batch_size):
121+
def get_train_data_loader(train_ds, state, batch_size):
122122
images_np = train_ds['image']
123123
labels_np = train_ds['label']
124124
steps_per_epoch = len(images_np) // batch_size
@@ -163,7 +163,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict,
163163
rng = jax.random.PRNGKey(0)
164164
state = create_train_state(rng, config)
165165

166-
train_data_loader, steps_per_epoch = get_train_data_laoder(
166+
train_data_loader, steps_per_epoch = get_train_data_loader(
167167
train_ds, state, config.batch_size)
168168

169169
for epoch in range(1, config.num_epochs + 1):

0 commit comments

Comments
 (0)