@@ -232,7 +232,7 @@ def setup_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
     return model
 
 
-def setup_compilation(model, backend: str):
+def setup_compilation(model, backend: str, thunder_cache: str | None = None):
     # TODO from thunder.executors.transformer_engineex import transformer_engine_ex
     """Apply compilation settings to the model."""
     if backend in ("thunder", "inductor"):
@@ -262,14 +262,15 @@ def setup_compilation(model, backend: str):
 
     if "jit" in backend:
         logger.info("Using thunder.jit")
-        model = thunder.jit(model, transforms=xforms, executors=executors)
+        model = thunder.jit(model, transforms=xforms, executors=executors, cache=thunder_cache)
     else:
         logger.info("Using ThunderFX")
         from thunder.dynamo import thunderfx
 
         # TODO get parameters out from thunderfx CompiledObject
-        compiled_object = thunderfx(model, transforms=xforms, executors=executors)
+        compiled_object = thunderfx(model, transforms=xforms, executors=executors, cache=thunder_cache)
         model = compiled_object._func
+        model._thunder_backend = compiled_object._backend
 
     return model
 
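The cache option threaded through above is the one exposed by `thunder.jit` itself. A minimal standalone sketch of what it controls, assuming lightning-thunder is installed and using the "symbolic values" mode listed in the new --thunder-cache help text (the toy function is a placeholder, not part of this change):

# Hedged sketch, not part of this commit: exercising thunder.jit's cache argument directly.
import torch
import thunder

def toy_fn(x):
    return x * 2 + 1

jit_fn = thunder.jit(toy_fn, cache="symbolic values")
jit_fn(torch.randn(4))
jit_fn(torch.randn(8))  # new input shape; symbolic-value caching is meant to avoid a full retrace here
print(thunder.cache_misses(jit_fn))  # compilations so far, the same counter the training summary below reports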
@@ -302,6 +303,12 @@ def parse_args():
         type=str.lower,
         choices=["eager", "inductor", "thunder", "thunder+jit"],
     )
+    parser.add_argument(
+        "--thunder-cache",
+        type=str,
+        default=None,
+        help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.",
+    )
     parser.add_argument("--verbose", action="store_true", help="Enable verbose output including model wrapping details")
     parser.add_argument("--trust-remote-code", action="store_true")
     parser.add_argument(
@@ -468,7 +475,7 @@ def main(args: argparse.Namespace):
     # Apply compilation if needed
     if args.compile != "eager":
         logger.info(f"Applying compilation: {args.compile} to model")
-        model = setup_compilation(model, args.compile)
+        model = setup_compilation(model, args.compile, thunder_cache=args.thunder_cache)
         logger.info("Compilation applied to model")
 
     # Verify only LoRA parameters are trainable
@@ -583,6 +590,21 @@ def main(args: argparse.Namespace):
 
     # Print training summary
     total_time = time.time() - start_ts
+
+    # Compute Thunder recompilations (cache misses minus initial compile) if applicable
+    num_recompilations = 0
+    if "thunder" in args.compile:
+        import thunder
+
+        if "jit" in args.compile:
+            num_recompilations = thunder.cache_misses(model) - 1
+        else:
+            total_misses = 0
+            for subgraph_info in model._thunder_backend.subgraph_infos:
+                for thunder_fn in subgraph_info.thunder_compiled_fns:
+                    total_misses += thunder.cache_misses(thunder_fn) - 1
+            num_recompilations = total_misses
+
     print_training_summary(
         args,
         total_time,
@@ -593,6 +615,7 @@ def main(args: argparse.Namespace):
         batches_processed,
         total_tokens_processed,
         WORLD_SIZE,
+        num_recompilations,
     )
 
     # Clean up distributed environment if needed
@@ -610,6 +633,7 @@ def print_training_summary(
     batches_processed: int,
     total_tokens_processed: int,
     WORLD_SIZE: int,
+    num_recompilations: int,
 ) -> None:
     """Print a comprehensive summary of the training run.
 
@@ -650,6 +674,8 @@ def print_training_summary(
     logger.info(f"Maximum allocated memory: {max_allocated_memory / 1024**3:.2f} GB")
     logger.info(f"Total tokens processed: {total_tokens:,}")
     logger.info(f"Total iterations: {args.max_steps}")
+    if "thunder" in args.compile:
+        logger.info(f"Thunder module recompilations excluding initial compile: {num_recompilations}")
 
     # Verify batch processing across all ranks
     if WORLD_SIZE > 1:
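For reference, a hypothetical end-to-end invocation exercising the new option; the launcher, script name, and --max-steps flag are placeholders not confirmed by this diff, while --compile and --thunder-cache are the options it defines:

torchrun --nproc-per-node=2 finetune_lora.py --compile thunder+jit --thunder-cache "symbolic values" --max-steps 20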