
Commit cd7be3d

Revert "Add stopwatch timer to train.py (#208)"

This reverts commit 0eee152.

1 parent 0eee152 · commit cd7be3d

15 files changed: +44 -172 lines changed


scripts/all_reduce.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 
 from zeroband.collectives import Compression, all_reduce
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 
 from enum import Enum
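This rename from `zeroband.utils.logger` to `zeroband.utils.logging` repeats across most files in the commit; the restored `zeroband.utils.logging.get_logger` itself is not shown in the diff. A minimal sketch of what such a per-process logger factory typically looks like (module contents, env-var names, and format here are assumptions, not the repository's code):

import logging
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_logger() -> logging.Logger:
    """Return a process-wide logger, configured once per process."""
    logger = logging.getLogger("zeroband")
    if not logger.handlers:
        handler = logging.StreamHandler()
        # hypothetical: tag each record with the distributed rank
        rank = os.environ.get("RANK", "0")
        handler.setFormatter(logging.Formatter(f"[rank {rank}] %(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        # hypothetical env var; the real project may configure levels differently
        logger.setLevel(os.environ.get("ZERO_BAND_LOG_LEVEL", "INFO"))
    return logger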

scripts/convert_dl_state.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 from zeroband.data import get_dataloader
 from transformers import AutoTokenizer
 from zeroband.train import Config
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 from pydantic_config import parse_argv
 
 COMMON_KEYS = [

scripts/export_dcp.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from zeroband.checkpoint import ModelWrapper
 from zeroband.utils import get_module_signature
 from zeroband.train import Config
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 from pydantic_config import parse_argv
 from transformers import AutoTokenizer
 import math

scripts/skip_data.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@
 from zeroband.data import get_dataloader
 
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 
 
 def skip_data(config: Config):

src/zeroband/checkpoint.py

Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@
     send_tensor_and_state_dict,
 )
 from distributed_shampoo import DistributedShampoo
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 from zeroband.config import CkptConfig
 from zeroband.utils.world_info import get_world_info
 

@@ -151,7 +151,7 @@ def non_error_barrier():
     try:
         dist.barrier()
     except Exception as e:
-        from zeroband.utils.logger import get_logger
+        from zeroband.utils.logging import get_logger
         get_logger().info(f"Error in data checkpointing barrier: {e}, continuing training")
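The second hunk touches `non_error_barrier`, a barrier that logs and continues instead of raising, so one straggling rank cannot abort training during data checkpointing. Assembled standalone from the visible lines (a minimal sketch, assuming an initialized default process group; the local import mirrors the diff and avoids an import cycle at module load):

import torch.distributed as dist

def non_error_barrier() -> None:
    # Synchronize ranks, but never let a failed barrier kill training.
    try:
        dist.barrier()
    except Exception as e:
        from zeroband.utils.logging import get_logger  # local import, as in the diff
        get_logger().info(f"Error in data checkpointing barrier: {e}, continuing training")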

src/zeroband/comms.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 import subprocess
 from torch.distributed.device_mesh import init_device_mesh
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 import torch.distributed as dist
 from datetime import timedelta
 from typing import List, Tuple, Optional

src/zeroband/data.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from typing import Any, Generator, Optional, List, Dict, TypedDict, Union
 import functools
 
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 from zeroband.config import DataConfig
 
 import torch

@@ -222,7 +222,7 @@ class InterleaveDataset(IterableDataset, Stateful):
     The state can be saved and restored. Under the hood we just fast forward the random generator to the current position.
     """
 
-    def __init__(self, datasets: List[ParquetDataset], probabilities: List[float], seed: int = 42):
+    def __init__(self, datasets: List[ParquetDataset], probabilities: Optional[List[float]] = None, seed: int = 42):
         assert len(datasets) > 0, "At least one dataset is required"
         assert len(datasets) == len(probabilities), "The number of datasets and probabilities must be the same"
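Note that the restored signature accepts `probabilities=None`, while the unchanged assertion on the following line still calls `len(probabilities)`, which would raise on `None`; presumably the argument is meant to be normalized first. A minimal sketch of how a uniform default could be resolved (an assumption about intent, not the repository's actual code):

from typing import List, Optional

def resolve_probabilities(num_datasets: int, probabilities: Optional[List[float]]) -> List[float]:
    # Default to uniform sampling when no probabilities are given.
    if probabilities is None:
        return [1.0 / num_datasets] * num_datasets
    assert len(probabilities) == num_datasets, "The number of datasets and probabilities must be the same"
    return probabilities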

src/zeroband/diloco.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.collectives import Compression, all_reduce
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 from zeroband.config import DilocoConfig
 import torch.distributed as dist
 from torch.distributed._tensor.api import DTensor

src/zeroband/train.py

Lines changed: 24 additions & 56 deletions

@@ -33,8 +33,7 @@
 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.profiler import MemoryProfiler
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
-from zeroband.utils.stopwatch import Stopwatch
+from zeroband.utils.logging import get_logger
 
 from transformers import AutoTokenizer
 from pydantic_config import parse_argv

@@ -94,11 +93,6 @@ def train(config: Config):
         config.ckpt.interval % config.diloco.inner_steps == 0
     ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
 
-    sw = Stopwatch(config)
-    sw.start("train()")
-
-    # Load tokenizer
-    sw.start_block()
     if config.data.fake and config.name_model == "debugmodel":
         tokenizer = FakeTokenizer()
     elif config.type_model == "llama2":

@@ -107,10 +101,11 @@ def train(config: Config):
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
     else:
         raise ValueError(f"Model type {config.type_model} not supported")
-    sw.end_block("tokenizer loaded")
+
+    logger.debug("tokenizer loaded")
 
     with record_function("Get dataloader"):
-        sw.start_block()
+        logger.debug("Getting dataloader")
         train_dataloader = get_dataloader(
             tokenizer=tokenizer,
             world_size=world_info.world_size,

@@ -119,16 +114,13 @@ def train(config: Config):
             data_config=config.data,
         )
         train_dataloader_iterator = iter(train_dataloader)
-        sw.end_block("dataloader loaded")
 
     with record_function("Get model"):
-        sw.start_block("Constructing model")
+        logger.debug("Constructing model")
         model, model_config = get_model(
             config,
             vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         )
-        sw.end_block("Constructed model")
-
 
     gpu_peak_flops = get_peak_flops(torch.cuda.get_device_name(torch.device("cuda")))
     logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

@@ -142,7 +134,6 @@ def train(config: Config):
     )
 
     with record_function("Shard model"):
-        sw.start_block("Sharding model")
         if config.train.ac_ckpt:
             num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
             apply_ac_ckpt(model, num)

@@ -178,11 +169,10 @@ def train(config: Config):
             reshard_after_forward=config.train.reshard_after_forward,
             offload_policy=offload_policy,
         )
-        sw.end_block()
+        logger.debug("model fsdped")
 
     # Setup optimizers
     with record_function("Set up Optimizers"):
-        sw.start_block()
         inner_optimizer = get_optimizer(config, model.parameters())
 
         diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None

@@ -209,7 +199,7 @@ def train(config: Config):
             diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,  # type: ignore
         )
 
-        sw.end_block("Optimizers set up")
+        logger.debug("Optimizers set up.")
 
     if world_info.rank == 0:
         logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger

@@ -223,15 +213,12 @@ def train(config: Config):
 
     with record_function("Compile model"):
         if config.train.torch_compile:
-            sw.start_block()
             # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY
             model = torch.compile(model) if not TYPE_CHECKING else model
-            sw.end_block("model compiled")
+            logger.debug("model compiled")
 
     with record_function("Resume checkpoint"):
         if config.ckpt.resume is not None:
-            sw.start_block("Resuming checkpoint")
-
             # all is inplace
             ckpt_manager.load(
                 resume_ckpt_path=config.ckpt.resume,

@@ -242,8 +229,6 @@ def train(config: Config):
                 config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
            )
 
-            sw.end_block("Checkpoint resumed")
-
     if config.train.memory_profiler is not None:
         memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)
 

@@ -254,7 +239,7 @@ def train(config: Config):
     num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1
     perf_counter = PerfCounter(window_size=10)
 
-    logger.debug("Finished setup in %f seconds", sw.elapsed())
+    logger.info("starting training")
 
     need_live_recovery = config.ckpt.live_recovery_rank_src is not None
     while True:

@@ -312,21 +297,19 @@ def train(config: Config):
 
         for inner_step in range(num_inner_steps):
             logger.debug("Starting inner step.")
-            sw.start("inner_step")
 
             loss_batch = 0
             z_loss_batch = 0
 
-            sw.start_block("Running grad acc steps")
             for grad_acc_step in range(gradient_accumulation_steps):
-                sw.start("grad_acc_step")
+                logger.debug("Starting gradient accumulation step.")
 
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
                 # no sync if we are accumulating gradients
                 model.set_requires_gradient_sync(not is_accumulating)
 
                 with record_function("Load batch"):
-                    sw.start_block()
+                    logger.debug("Loading batch")
                     # TODO/NOTE: We could overlap sending the batch with communication
                     # although to be honest the perf impact is minimal
                     batch = next(train_dataloader_iterator)

@@ -338,17 +321,15 @@ def train(config: Config):
                 else:
                     seqlens = None
                     block_mask = None
-                    sw.end_block("batch loaded")
 
                 with record_function("Run model"):
-                    sw.start_block()
+                    logger.debug("Running forward()")
                     logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
-                    flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b * seq) vocab
-                    flatten_labels = labels.reshape(-1)  # b seq -> (b * seq)
-                    sw.end_block("Ran forward()")
+                    flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b seq) vocab
+                    flatten_labels = labels.reshape(-1)  # b seq -> (b seq)
 
                 with record_function("Loss calculation"):
-                    sw.start_block()
+                    logger.debug("Computing loss")
                     ce_loss, z_loss = compute_cross_entropy_loss(
                         flatten_logits,
                         flatten_labels,

@@ -368,28 +349,22 @@ def train(config: Config):
                         loss = ce_loss + z_loss
                     else:
                         loss = ce_loss / gradient_accumulation_steps
-                    sw.end_block("Loss computed")
 
                 with record_function("Backward"):
-                    sw.start_block()
+                    logger.debug("Running backward()")
                     loss.backward()
-                    sw.end_block("Ran backward()")
 
                 with record_function("Clone loss"):
-                    # No need to time, takes 0 seconds
+                    logger.debug("Cloning loss")
                     if config.optim.z_loss:
                         assert z_loss is not None
                         loss_batch += ce_loss.detach().clone()
                         z_loss_batch += z_loss.detach().clone()
                     else:
                         loss_batch += loss.detach().clone()
 
-                elapsed = sw.stop("grad_acc_step")
-                logger.debug(f"Grad acc step {grad_acc_step} completed in {elapsed:.2f} seconds")
-            sw.end_block("Finished grad acc steps")
-
-            with record_function("Loss allreduce"):
-                sw.start_block()
+            with record_function("Inner allreduce"):
+                logger.debug("loss allreduce()")
                 # Launch both allreduces at the same time to hide latency
                 loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
                 if config.optim.z_loss:
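The hunk above keeps the pattern of launching both loss all-reduces with `async_op=True` and waiting on them afterwards, so the two communications overlap instead of running back to back. A minimal standalone sketch of that pattern (a sketch, assuming an initialized process group with a backend supporting `ReduceOp.AVG` such as NCCL; the real code scopes the collectives to `elastic_device_mesh.local_pg`):

from typing import Tuple
import torch
import torch.distributed as dist

def averaged_losses(loss: torch.Tensor, z_loss: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Launch both all-reduces before waiting on either, so the second
    # communication overlaps with the first instead of following it.
    work_loss = dist.all_reduce(loss, op=dist.ReduceOp.AVG, async_op=True)
    work_z = dist.all_reduce(z_loss, op=dist.ReduceOp.AVG, async_op=True)
    work_loss.wait()
    work_z.wait()
    return loss, z_loss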
@@ -400,22 +375,18 @@ def train(config: Config):
                 if config.optim.z_loss:
                     assert isinstance(z_loss_allreduce, torch.distributed.Work)
                     z_loss_allreduce.wait()
-                sw.end_block("loss allreduced")
 
             with record_function("Clip grad"):
-                sw.start_block()
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor()  # type: ignore (is a dtensor)
-                sw.end_block("Clipped grad")
+                logger.debug("clipping grad")
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor()
+                # full tensor needed because grad_norm is a DTensor
 
             with record_function("Optimizer step"):
-                sw.start_block()
+                logger.debug("inner optimizer step()")
                 inner_optimizer.step()
                 scheduler.step()
-                sw.end_block("Inner optimizer step()")
-
-                sw.start_block()
+                logger.debug("inner optimizer zero_grad()")
                 inner_optimizer.zero_grad()
-                sw.end_block("inner optimizer zero_grad()")
 
             # logging
             training_progress.step += 1

@@ -472,9 +443,6 @@ def train(config: Config):
             if config.train.memory_profiler is not None:
                 memory_profiler.step()
 
-            elapsed = sw.stop("inner_step")
-            logger.debug(f"Inner step {inner_step} completed in {elapsed:.2f} seconds")
-
         if config.diloco is not None:
             assert diloco is not None
             if world_info.rank == 0 and config.monitor is not None:
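For reference, the reverted `zeroband.utils.stopwatch.Stopwatch` is not shown in this commit. Judging only from the call sites removed above (named `start`/`stop` timers, `start_block`/`end_block` pairs with optional messages, a whole-run `elapsed()`), a minimal sketch of a compatible interface might look like this (an approximation of the interface, not the original implementation; the `config` argument is accepted but unused here):

import time
from typing import Dict, List, Optional

class Stopwatch:
    """Named timers plus a stack of anonymous timed blocks."""

    def __init__(self, config=None):
        # config presumably gated logging/enablement in the original; unused in this sketch
        self._created = time.perf_counter()
        self._starts: Dict[str, float] = {}
        self._blocks: List[float] = []

    def start(self, name: str) -> None:
        self._starts[name] = time.perf_counter()

    def stop(self, name: str) -> float:
        # Seconds since the matching start() call.
        return time.perf_counter() - self._starts.pop(name)

    def start_block(self, msg: Optional[str] = None) -> None:
        self._blocks.append(time.perf_counter())

    def end_block(self, msg: Optional[str] = None) -> float:
        # Seconds since the matching start_block() call.
        return time.perf_counter() - self._blocks.pop()

    def elapsed(self) -> float:
        # Seconds since the stopwatch was created.
        return time.perf_counter() - self._created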

src/zeroband/utils/activation_ckpt.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
 
-from zeroband.utils.logger import get_logger
+from zeroband.utils.logging import get_logger
 
 
 def apply_ac_ckpt(model: Transformer, num: int):

@@ -21,4 +21,4 @@ def apply_ac_ckpt(model: Transformer, num: int):
             model.layers.register_module(layer_id, transformer_block)
             layers_ckpt += 1
 
-    logger.debug(f"Applied activation checkpointing to {layers_ckpt} layers")
+    logger.info(f"Applied activation checkpointing to {layers_ckpt} layers")
