Commit 0879e68

finish merge

Merge of 2 parents: 7edb702 + 0dc16db

4 files changed: +37 additions, -31 deletions

4 files changed

+37
-31
lines changed

algoperf/workloads/lm/input_pipeline.py

Lines changed: 17 additions & 12 deletions
@@ -54,14 +54,14 @@ def batch_with_padding(
 def get_data_iter(data_rng: jax.random.PRNGKey,
                   split: str,
                   data_dir: str,
-                  global_batch_size: int,
+                  batch_size: int,
                   num_batches: Optional[int] = None,):
 
-  ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches)
+  ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches)
 
   it = map(
       functools.partial(
-          data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+          data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size
       ),
       ds,
   )
@@ -72,7 +72,7 @@ def get_lm_dataset(
     data_rng: jax.random.PRNGKey,
     split: str,
     data_dir: str,
-    global_batch_size: int,
+    batch_size: int,
     num_batches: Optional[int] = None,
 ):
   """Load preprocessed TF dataset."""
@@ -98,14 +98,15 @@ def get_lm_dataset(
       },
       num_parallel_calls=AUTOTUNE,
   )
-  sequences_ds = sequences_ds.repeat()
   if split == 'train':
     ds = sequences_ds.shuffle(
         SHUFFLE_BUFFER_SIZE, seed=shuffle_seed
     )
     ds = ds.batch(
-        global_batch_size, drop_remainder=False
+        batch_size, drop_remainder=False
     )
+    ds = ds.take(num_batches) if num_batches is not None else ds
+    ds = ds.repeat()
     ds = ds.map(lambda x: {
         'inputs': x['inputs'],
         'targets': x['targets'],
@@ -115,12 +116,14 @@ def get_lm_dataset(
   elif split == 'eval_train':
     ds = batch_with_padding(
         sequences_ds,
-        global_batch_size,
+        batch_size,
         padded_shapes={
-            'inputs': (global_batch_size, None),
-            'targets': (global_batch_size, None),
+            'inputs': (batch_size, None),
+            'targets': (batch_size, None),
         },
     )
+    ds = ds.take(num_batches) if num_batches is not None else ds
+    ds = ds.repeat()
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)
@@ -129,12 +132,14 @@ def get_lm_dataset(
   elif split == 'validation':
     ds = batch_with_padding(
         sequences_ds,
-        global_batch_size,
+        batch_size,
         padded_shapes={
-            'inputs': (global_batch_size, None),
-            'targets': (global_batch_size, None),
+            'inputs': (batch_size, None),
+            'targets': (batch_size, None),
         },
     )
+    ds = ds.take(num_batches) if num_batches is not None else ds
+    ds = ds.repeat()
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)
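Note on the reordering above: previously sequences_ds.repeat() ran before shuffling and batching, so the dataset was already infinite and num_batches could not bound it. Moving repeat() after batch() and take(num_batches) means each split first fixes a set of num_batches batches and then cycles over exactly those. A minimal standalone sketch of the resulting semantics (not from the repo; toy range dataset and a hypothetical num_batches value):

    import tensorflow as tf

    # Toy stand-in for the batched sequence dataset.
    ds = tf.data.Dataset.range(10)
    ds = ds.batch(4, drop_remainder=False)  # batches: [0..3], [4..7], [8, 9]

    num_batches = 2  # hypothetical cap, mirroring the `num_batches` argument
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()  # cycle over the same 2 batches indefinitely

    it = iter(ds)
    for _ in range(4):
      print(next(it).numpy())  # [0 1 2 3], [4 5 6 7], [0 1 2 3], [4 5 6 7]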

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 7 additions & 6 deletions
@@ -21,16 +21,17 @@ def _build_input_queue(self,
                          split: str,
                          data_dir: str,
                          global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+                         cache: Optional[bool] = None,
+                         repeat_final_dataset: Optional[bool] = None,
+                         num_batches: Optional[int] = None):
     """Build an input queue using pre-cached FineWeb dataset."""
-    del num_batches
-    del repeat_final_dataset
+    del cache, repeat_final_dataset
     ds = get_data_iter(
         data_rng=data_rng,
         split=split,
         data_dir=data_dir,
-        global_batch_size=global_batch_size)
+        batch_size=global_batch_size,
+        num_batches=num_batches)
     ds = map(jax_sharding_utils.shard_along_batch_dim, ds)
     return ds
 
@@ -71,7 +72,7 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode, rng, update_batch_norm, model_state, dropout_rate
     inputs = batch['inputs']
     # Convert one-hot inputs to token IDs if needed
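The JAX workload now mirrors the base-class keyword set (cache, repeat_final_dataset, num_batches) and forwards num_batches instead of discarding it; dropout_rate's default also changes from None to 0.0, matching its float annotation. A hedged sketch of the accept-and-discard pattern used for keyword parity (class and method bodies below are illustrative, not the repo's code):

    from typing import Any, Dict, Iterator, Optional

    class BaseWorkload:
      def _build_input_queue(self,
                             data_rng,
                             split: str,
                             data_dir: str,
                             global_batch_size: int,
                             cache: Optional[bool] = None,
                             repeat_final_dataset: Optional[bool] = None,
                             num_batches: Optional[int] = None
                             ) -> Iterator[Dict[str, Any]]:
        raise NotImplementedError

    class ToyJaxWorkload(BaseWorkload):
      def _build_input_queue(self,
                             data_rng,
                             split: str,
                             data_dir: str,
                             global_batch_size: int,
                             cache: Optional[bool] = None,
                             repeat_final_dataset: Optional[bool] = None,
                             num_batches: Optional[int] = None
                             ) -> Iterator[Dict[str, Any]]:
        # Accepted for signature parity with the base class, unused here.
        del cache, repeat_final_dataset
        return iter([])  # placeholder body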

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 5 additions & 5 deletions
@@ -98,17 +98,17 @@ def _build_input_queue(
       split: str,
       data_dir: str,
       global_batch_size: int,
-      num_batches: Optional[int] = None,
-      repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
     """Build an input queue for the given split."""
+    del cache, repeat_final_dataset
     local_batch_size = global_batch_size // N_GPUS
-    # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np
-    # from seeing all GPUs via torch.cuda.device_count()
     loader = get_data_iter(
         data_rng=data_rng,
         split=split,
         data_dir=data_dir,
-        global_batch_size=local_batch_size,
+        batch_size=local_batch_size,
         num_batches=num_batches,
     )
     if USE_PYTORCH_DDP:
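Because this builder runs once per DDP process, it requests only the local share of the global batch; DDP averages gradients across processes, so the effective batch size is unchanged. A minimal arithmetic sketch (values are illustrative; N_GPUS in the repo comes from the PyTorch setup):

    global_batch_size = 512
    n_gpus = 8  # stand-in for N_GPUS under DDP
    local_batch_size = global_batch_size // n_gpus  # 64 examples per process
    assert local_batch_size * n_gpus == global_batch_size, (
        'global batch size should divide evenly across GPUs')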

algoperf/workloads/lm/workload.py

Lines changed: 8 additions & 8 deletions
@@ -2,11 +2,11 @@
 
 import abc
 import math
+import numpy as np
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Iterator
 
 import jax
-import numpy as np
 from absl import flags
 
 from algoperf import spec
@@ -85,11 +85,11 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 3600 * 14  # 14 hours
+    return 3600 * 14  # 14 hours TODO(kasimbeg): update
 
   @property
   def eval_period_time_sec(self) -> int:
-    return 1200  # 20 minutes
+    return 1200  # 20 minutes TODO(kasimbeg): update
 
   @property
   def step_hint(self) -> int:
@@ -119,9 +119,10 @@ def _build_input_queue(
       split: str,
       data_dir: str,
       global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
       num_batches: Optional[int] = None,
-      repeat_final_dataset: bool = False,
-  ):
+  ) -> Iterator[Dict[str, Any]]:
     """Build an input queue for the given split."""
 
 
@@ -150,8 +151,7 @@ def _eval_model_on_split(
         split,
         data_dir,
         global_batch_size,
-        num_batches,
-        repeat_final_dataset=True,
+        num_batches=num_batches
     )
 
     eval_metrics = {}
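With the pipeline now applying take(num_batches) before repeat(), _eval_model_on_split no longer needs a repeat_final_dataset=True flag: the queue already cycles, and the caller just consumes a fixed number of batches. A runnable toy sketch of that consumption pattern (the generator below is a stand-in for the real input queue, not the repo's code):

    import itertools

    def toy_input_queue():
      # Stand-in for the repeated tf.data iterator: yields the same two
      # batches forever, like take(2).repeat().
      while True:
        for i in range(2):
          yield {'inputs': i}

    num_eval_batches = 5
    for batch in itertools.islice(toy_input_queue(), num_eval_batches):
      print(batch)  # cycles over the same 2 batches until 5 are consumed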
