
Commit c101dc2

Add --thunder-cache option and count recompiles
1 parent fb989d4 commit c101dc2

3 files changed: +82 −5 lines

thunder/benchmarks/benchmark_inference.py

Lines changed: 27 additions & 0 deletions
@@ -246,6 +246,9 @@ class InferenceMetrics:
     prefill_time_ms: float = 0.0
     decode_time_ms: float = 0.0
 
+    # Compilation metrics
+    num_recompilations: int = 0
+
     # Per-iteration metrics for variance analysis
     iteration_times: list[float] = field(default_factory=list)
     ttft_times: list[float] = field(default_factory=list)

@@ -572,6 +575,8 @@ def run_benchmark(self) -> InferenceMetrics:
 
         self._calculate_aggregate_metrics(all_metrics)
 
+        self.metrics.num_recompilations = self._get_recompilation_count()
+
         if torch.cuda.is_available():
             self.metrics.memory_used_gb = torch.cuda.memory_allocated() / 1e9
             self.metrics.peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9

@@ -582,6 +587,20 @@ def run_benchmark(self) -> InferenceMetrics:
 
         return self.metrics
 
+    def _get_recompilation_count(self) -> int:
+        """Count Thunder cache misses (recompiles of the same jitted module), excluding the initial compile and Dynamo graph rebuilds."""
+        if self.config.mode == "thunder":
+            backend = self.model._backend
+            total_misses = 0
+            for subgraph_info in backend.subgraph_infos:
+                for thunder_fn in subgraph_info.thunder_compiled_fns:
+                    total_misses += thunder.cache_misses(thunder_fn) - 1
+            return total_misses
+        elif self.config.mode == "thunderjit":
+            return thunder.cache_misses(self.model) - 1
+        else:
+            return 0
+
     def _calculate_aggregate_metrics(self, all_metrics: list[dict[str, Any]]):
         """Calculate aggregate metrics from individual iterations"""
         # Average throughput

@@ -638,6 +657,13 @@ def print_results(self):
             print(f" Current Memory: {self.metrics.memory_used_gb:.2f} GB")
             print(f" Peak Memory: {self.metrics.peak_memory_gb:.2f} GB")
 
+        if self.config.mode in ("thunder", "thunderjit"):
+            print("\nCompilation Metrics:")
+            print(
+                " Number of Thunder module recompilations excluding initial compile: "
+                f"{self.metrics.num_recompilations}"
+            )
+
         if len(self.metrics.iteration_times) > 1:
             print("\nVariance Analysis:")
             print(f" Throughput Std Dev: {statistics.stdev(self.metrics.iteration_times):.2f} ms")

@@ -659,6 +685,7 @@ def save_results(self, filename: str):
                 "total_time_ms": self.metrics.total_time_ms,
                 "memory_used_gb": self.metrics.memory_used_gb,
                 "peak_memory_gb": self.metrics.peak_memory_gb,
+                "num_recompilations": self.metrics.num_recompilations,
            },
            "detailed_metrics": {
                "iteration_times": self.metrics.iteration_times,

thunder/benchmarks/benchmark_litgpt.py

Lines changed: 25 additions & 1 deletion
@@ -315,6 +315,7 @@ def __init__(
         fp8_shard_intermediate_activation: bool = False,
         use_sdpa: bool = False,
         use_hf: bool = False,
+        thunder_cache: str | None = None,
     ):
         seed = 1337
         torch.manual_seed(seed)

@@ -350,6 +351,7 @@ def __init__(
 
         self.use_sdpa = use_sdpa
         self.use_hf = use_hf
+        self.thunder_cache = thunder_cache
 
         if self.use_sdpa and sdpa_available and self.compile not in ["eager", "inductor"]:
             warnings.warn(

@@ -692,7 +694,7 @@ def setup_compile(self, model):
             transforms.insert(0, TransformerEngineTransform())
 
         if "jit" in self.compile:
-            model = thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
+            model = thunder.jit(model, executors=executors, transforms=transforms, cache=self.thunder_cache)
 
         else:
             if self.distributed_mode == "fsdp2":

@@ -862,6 +864,7 @@ def train(self):
         self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iters)
         self.perf_metrics["saved_for_backward_tensor_size_mib"] = saved_tensors_size_in_mib
         self.perf_metrics["saved_for_backward_number_of_tensors"] = saved_tensors_len
+        self.perf_metrics["num_recompilations"] = self.compute_num_recompilations()
 
     def add_perf_metrics(self):
         if self.throughput:

@@ -893,6 +896,22 @@ def add_model_info_to_metrics(self):
             self.perf_metrics["Sharding Size"] = None
         self.perf_metrics["compiler"] = self.compile
 
+    def compute_num_recompilations(self) -> int:
+        import thunder
+
+        if "thunder" not in self.compile:
+            return 0
+
+        if "jit" in self.compile:
+            return thunder.cache_misses(self.model) - 1
+
+        # Compiled by ThunderFX
+        total_misses = 0
+        for info in self.backend.subgraph_infos:
+            for thunder_fn in info.thunder_compiled_fns:
+                total_misses += thunder.cache_misses(thunder_fn) - 1
+        return total_misses
+
 
 class DummyDataset(IterableDataset):
     def __init__(self, max_seq_length: int, dynamic: bool):

@@ -979,6 +998,11 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
         print(
             f"Saved for backward number of tensors: {benchmark.perf_metrics['saved_for_backward_number_of_tensors']}"
         )
+        if "thunder" in benchmark.compile:
+            print(
+                "Thunder module recompilations excluding initial compile: "
+                f"{benchmark.perf_metrics.get('num_recompilations', 0)}"
+            )
 
     tokens_per_sec = benchmark.perf_metrics.get("tokens_per_sec")
     if tokens_per_sec:
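
A short sketch of how the new `thunder_cache` knob is expected to interact with the recompilation count, assuming the `cache` options named in the `--thunder-cache` help text ("no caching", "same input", "constant values", "symbolic values"); the retracing behaviour depends on the chosen policy, so the printed values are expectations rather than guarantees.

import torch
import thunder


def f(x):
    return x * 2


# Default policy ("constant values"): input metadata such as shape is baked into
# the cached trace, so a new shape should register as a cache miss / recompile.
f_const = thunder.jit(f, cache="constant values")
f_const(torch.randn(4))
f_const(torch.randn(8))
print(thunder.cache_misses(f_const) - 1)  # expected 1

# "symbolic values": shapes are treated symbolically, so varying shapes
# should be able to reuse the cached trace instead of retracing.
f_sym = thunder.jit(f, cache="symbolic values")
f_sym(torch.randn(4))
f_sym(torch.randn(8))
print(thunder.cache_misses(f_sym) - 1)  # expected 0 if the symbolic cache hits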

thunder/benchmarks/benchmark_peft.py

Lines changed: 30 additions & 4 deletions
@@ -232,7 +232,7 @@ def setup_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
232232
return model
233233

234234

235-
def setup_compilation(model, backend: str):
235+
def setup_compilation(model, backend: str, thunder_cache: str | None = None):
236236
# TODO from thunder.executors.transformer_engineex import transformer_engine_ex
237237
"""Apply compilation settings to the model."""
238238
if backend in ("thunder", "inductor"):
@@ -262,14 +262,15 @@ def setup_compilation(model, backend: str):
262262

263263
if "jit" in backend:
264264
logger.info("Using thunder.jit")
265-
model = thunder.jit(model, transforms=xforms, executors=executors)
265+
model = thunder.jit(model, transforms=xforms, executors=executors, cache=thunder_cache)
266266
else:
267267
logger.info("Using ThunderFX")
268268
from thunder.dynamo import thunderfx
269269

270270
# TODO get parameters out from thunderfx CompiledObject
271-
compiled_object = thunderfx(model, transforms=xforms, executors=executors)
271+
compiled_object = thunderfx(model, transforms=xforms, executors=executors, cache=thunder_cache)
272272
model = compiled_object._func
273+
model._thunder_backend = compiled_object._backend
273274

274275
return model
275276

@@ -302,6 +303,12 @@ def parse_args():
302303
type=str.lower,
303304
choices=["eager", "inductor", "thunder", "thunder+jit"],
304305
)
306+
parser.add_argument(
307+
"--thunder-cache",
308+
type=str,
309+
default=None,
310+
help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.",
311+
)
305312
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including model wrapping details")
306313
parser.add_argument("--trust-remote-code", action="store_true")
307314
parser.add_argument(
@@ -468,7 +475,7 @@ def main(args: argparse.Namespace):
468475
# Apply compilation if needed
469476
if args.compile != "eager":
470477
logger.info(f"Applying compilation: {args.compile} to model")
471-
model = setup_compilation(model, args.compile)
478+
model = setup_compilation(model, args.compile, thunder_cache=args.thunder_cache)
472479
logger.info("Compilation applied to model")
473480

474481
# Verify only LoRA parameters are trainable
@@ -583,6 +590,21 @@ def main(args: argparse.Namespace):
583590

584591
# Print training summary
585592
total_time = time.time() - start_ts
593+
594+
# Compute Thunder recompilations (cache misses minus initial compile) if applicable
595+
num_recompilations = 0
596+
if "thunder" in args.compile:
597+
import thunder
598+
599+
if "jit" in args.compile:
600+
num_recompilations = thunder.cache_misses(model) - 1
601+
else:
602+
total_misses = 0
603+
for subgraph_info in model._thunder_backend.subgraph_infos:
604+
for thunder_fn in subgraph_info.thunder_compiled_fns:
605+
total_misses += thunder.cache_misses(thunder_fn) - 1
606+
num_recompilations = total_misses
607+
586608
print_training_summary(
587609
args,
588610
total_time,
@@ -593,6 +615,7 @@ def main(args: argparse.Namespace):
593615
batches_processed,
594616
total_tokens_processed,
595617
WORLD_SIZE,
618+
num_recompilations,
596619
)
597620

598621
# Clean up distributed environment if needed
@@ -610,6 +633,7 @@ def print_training_summary(
610633
batches_processed: int,
611634
total_tokens_processed: int,
612635
WORLD_SIZE: int,
636+
num_recompilations: int,
613637
) -> None:
614638
"""Print a comprehensive summary of the training run.
615639
@@ -650,6 +674,8 @@ def print_training_summary(
650674
logger.info(f"Maximum allocated memory: {max_allocated_memory / 1024**3:.2f} GB")
651675
logger.info(f"Total tokens processed: {total_tokens:,}")
652676
logger.info(f"Total iterations: {args.max_steps}")
677+
if "thunder" in args.compile:
678+
logger.info(f"Thunder module recompilations excluding initial compile: {num_recompilations}")
653679

654680
# Verify batch processing across all ranks
655681
if WORLD_SIZE > 1:
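
For the ThunderFX path, a minimal sketch of the same accounting, using the private attributes the benchmarks rely on (`_func`, `_backend`, `subgraph_infos`, `thunder_compiled_fns`); this mirrors their usage rather than a public API, and Dynamo graph rebuilds are deliberately not counted.

import torch
import thunder
from thunder.dynamo import thunderfx


def f(x):
    return torch.sin(x) + x


compiled_object = thunderfx(f)
fn = compiled_object._func  # the callable actually executed, as in setup_compilation above

fn(torch.randn(4))
fn(torch.randn(4))

# Sum cache misses over every Thunder-compiled subgraph, subtracting each
# subgraph's initial compile, exactly as the benchmarks do.
total_misses = 0
for subgraph_info in compiled_object._backend.subgraph_infos:
    for thunder_fn in subgraph_info.thunder_compiled_fns:
        total_misses += thunder.cache_misses(thunder_fn) - 1
print(total_misses)  # 0 here, since the inputs never change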
