 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.profiler import MemoryProfiler
 from zeroband.utils.world_info import get_world_info
-from zeroband.utils.logger import get_logger
-from zeroband.utils.stopwatch import Stopwatch
+from zeroband.utils.logging import get_logger
 
 from transformers import AutoTokenizer
 from pydantic_config import parse_argv
@@ -94,11 +93,6 @@ def train(config: Config):
             config.ckpt.interval % config.diloco.inner_steps == 0
         ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
 
-    sw = Stopwatch(config)
-    sw.start("train()")
-
-    # Load tokenizer
-    sw.start_block()
     if config.data.fake and config.name_model == "debugmodel":
         tokenizer = FakeTokenizer()
     elif config.type_model == "llama2":
@@ -107,10 +101,11 @@ def train(config: Config):
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
     else:
         raise ValueError(f"Model type {config.type_model} not supported")
-    sw.end_block("tokenizer loaded")
+
+    logger.debug("tokenizer loaded")
 
     with record_function("Get dataloader"):
-        sw.start_block()
+        logger.debug("Getting dataloader")
         train_dataloader = get_dataloader(
             tokenizer=tokenizer,
             world_size=world_info.world_size,
@@ -119,16 +114,13 @@ def train(config: Config):
             data_config=config.data,
         )
         train_dataloader_iterator = iter(train_dataloader)
-        sw.end_block("dataloader loaded")
 
     with record_function("Get model"):
-        sw.start_block("Constructing model")
+        logger.debug("Constructing model")
         model, model_config = get_model(
             config,
             vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         )
-        sw.end_block("Constructed model")
-
 
     gpu_peak_flops = get_peak_flops(torch.cuda.get_device_name(torch.device("cuda")))
     logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
@@ -142,7 +134,6 @@ def train(config: Config):
     )
 
     with record_function("Shard model"):
-        sw.start_block("Sharding model")
         if config.train.ac_ckpt:
             num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
             apply_ac_ckpt(model, num)
@@ -178,11 +169,10 @@ def train(config: Config):
             reshard_after_forward=config.train.reshard_after_forward,
             offload_policy=offload_policy,
         )
-        sw.end_block()
+        logger.debug("model fsdped")
 
     # Setup optimizers
     with record_function("Set up Optimizers"):
-        sw.start_block()
         inner_optimizer = get_optimizer(config, model.parameters())
 
         diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None
@@ -209,7 +199,7 @@ def train(config: Config):
             diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,  # type: ignore
         )
 
-        sw.end_block("Optimizers set up")
+        logger.debug("Optimizers set up.")
 
     if world_info.rank == 0:
         logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
@@ -223,15 +213,12 @@ def train(config: Config):
 
     with record_function("Compile model"):
         if config.train.torch_compile:
-            sw.start_block()
             # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY
             model = torch.compile(model) if not TYPE_CHECKING else model
-            sw.end_block("model compiled")
+            logger.debug("model compiled")
 
     with record_function("Resume checkpoint"):
         if config.ckpt.resume is not None:
-            sw.start_block("Resuming checkpoint")
-
             # all is inplace
             ckpt_manager.load(
                 resume_ckpt_path=config.ckpt.resume,
@@ -242,8 +229,6 @@ def train(config: Config):
                 config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
             )
 
-            sw.end_block("Checkpoint resumed")
-
     if config.train.memory_profiler is not None:
         memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)
 
@@ -254,7 +239,7 @@ def train(config: Config):
     num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1
     perf_counter = PerfCounter(window_size=10)
 
-    logger.debug("Finished setup in %f seconds", sw.elapsed())
+    logger.info("starting training")
 
     need_live_recovery = config.ckpt.live_recovery_rank_src is not None
     while True:
@@ -312,21 +297,19 @@ def train(config: Config):
 
         for inner_step in range(num_inner_steps):
             logger.debug("Starting inner step.")
-            sw.start("inner_step")
 
             loss_batch = 0
             z_loss_batch = 0
 
-            sw.start_block("Running grad acc steps")
             for grad_acc_step in range(gradient_accumulation_steps):
-                sw.start("grad_acc_step")
+                logger.debug("Starting gradient accumulation step.")
 
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
                 # no sync if we are accumulating gradients
                 model.set_requires_gradient_sync(not is_accumulating)
 
                 with record_function("Load batch"):
-                    sw.start_block()
+                    logger.debug("Loading batch")
                     # TODO/NOTE: We could overlap sending the batch with communication
                     # although to be honest the perf impact is minimal
                     batch = next(train_dataloader_iterator)
@@ -338,17 +321,15 @@ def train(config: Config):
                     else:
                         seqlens = None
                         block_mask = None
-                    sw.end_block("batch loaded")
 
                 with record_function("Run model"):
-                    sw.start_block()
+                    logger.debug("Running forward()")
                     logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
-                    flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b * seq) vocab
-                    flatten_labels = labels.reshape(-1)  # b seq -> (b * seq)
-                    sw.end_block("Ran forward()")
+                    flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b seq) vocab
+                    flatten_labels = labels.reshape(-1)  # b seq -> (b seq)
 
                 with record_function("Loss calculation"):
-                    sw.start_block()
+                    logger.debug("Computing loss")
                     ce_loss, z_loss = compute_cross_entropy_loss(
                         flatten_logits,
                         flatten_labels,
@@ -368,28 +349,22 @@ def train(config: Config):
                         loss = ce_loss + z_loss
                     else:
                         loss = ce_loss / gradient_accumulation_steps
-                    sw.end_block("Loss computed")
 
                 with record_function("Backward"):
-                    sw.start_block()
+                    logger.debug("Running backward()")
                     loss.backward()
-                    sw.end_block("Ran backward()")
 
                 with record_function("Clone loss"):
-                    # No need to time, takes 0 seconds
+                    logger.debug("Cloning loss")
                     if config.optim.z_loss:
                         assert z_loss is not None
                         loss_batch += ce_loss.detach().clone()
                         z_loss_batch += z_loss.detach().clone()
                     else:
                         loss_batch += loss.detach().clone()
 
-                elapsed = sw.stop("grad_acc_step")
-                logger.debug(f"Grad acc step {grad_acc_step} completed in {elapsed:.2f} seconds")
-            sw.end_block("Finished grad acc steps")
-
-            with record_function("Loss allreduce"):
-                sw.start_block()
+            with record_function("Inner allreduce"):
+                logger.debug("loss allreduce()")
                 # Launch both allreduces at the same time to hide latency
                 loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
                 if config.optim.z_loss:
@@ -400,22 +375,18 @@ def train(config: Config):
                 if config.optim.z_loss:
                     assert isinstance(z_loss_allreduce, torch.distributed.Work)
                     z_loss_allreduce.wait()
-                sw.end_block("loss allreduced")
 
             with record_function("Clip grad"):
-                sw.start_block()
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor()  # type: ignore (is a dtensor)
-                sw.end_block("Clipped grad")
+                logger.debug("clipping grad")
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor()
+                # full tensor needed because grad_norm is a DTensor
 
             with record_function("Optimizer step"):
-                sw.start_block()
+                logger.debug("inner optimizer step()")
                 inner_optimizer.step()
                 scheduler.step()
-                sw.end_block("Inner optimizer step()")
-
-                sw.start_block()
+                logger.debug("inner optimizer zero_grad()")
                 inner_optimizer.zero_grad()
-                sw.end_block("inner optimizer zero_grad()")
 
             # logging
             training_progress.step += 1
@@ -472,9 +443,6 @@ def train(config: Config):
             if config.train.memory_profiler is not None:
                 memory_profiler.step()
 
-            elapsed = sw.stop("inner_step")
-            logger.debug(f"Inner step {inner_step} completed in {elapsed:.2f} seconds")
-
         if config.diloco is not None:
             assert diloco is not None
             if world_info.rank == 0 and config.monitor is not None:
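
Note on the "Inner allreduce" block above: both loss reductions are launched with `async_op=True` before either result is awaited, so the two collectives can proceed concurrently. Below is a minimal standalone sketch of that overlap pattern, not part of this commit; it assumes `torch.distributed` is already initialized with a backend that supports `ReduceOp.AVG` (e.g. NCCL), and the helper name is hypothetical.

```python
import torch
import torch.distributed as dist


def allreduce_losses(loss_batch: torch.Tensor, z_loss_batch=None, group=None):
    """Average loss tensors across ranks, overlapping the two collectives."""
    # Launch both all-reduces asynchronously so their communication overlaps.
    loss_work = dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG, group=group, async_op=True)
    z_work = None
    if z_loss_batch is not None:
        z_work = dist.all_reduce(z_loss_batch, op=dist.ReduceOp.AVG, group=group, async_op=True)

    # Block only after both collectives are in flight.
    loss_work.wait()
    if z_work is not None:
        z_work.wait()
    return loss_batch, z_loss_batch
```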