import os
from copy import deepcopy
from pathlib import Path
+from typing import Dict, Tuple

import torch

    create_model_and_input_data,
)

+# -----------------------------------------------------------------------------
+# Baseline caching
+#
+# ``_BASELINE_CACHE`` maps the key built by ``_make_cache_key(config)``, i.e.
+# (model_type, m, k, n, high_precision_dtype, device, torch_compile_mode),
+# to a tuple ``(eager_baseline_time, compile_baseline_time)`` in milliseconds.
+# E.g. ("linear", 1024, 1024, 1024, torch.bfloat16, "cuda", "default") -> (95.00, 56.00)
+# The cached baseline inference times are reused to calculate speedup metrics,
+# so each unique configuration pays for its baseline measurement only once,
+# which reduces the overall benchmarking time.
+# The cache is internal to this module; users should not access it directly.
+# -----------------------------------------------------------------------------
+
+_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}
+
+
+def _make_cache_key(config: BenchmarkConfig) -> Tuple:
+    """Create a key for caching based on the benchmark configuration.
+
+    Parameters that affect baseline performance are included:
+
+    * model type (e.g. ``linear`` or ``transformer_block``)
+    * shape dimensions (m, k, n)
+    * high precision dtype (bf16, fp16, etc.)
+    * device (cuda, cpu, mps)
+    * compile settings (``torch_compile_mode``)
+
+    Sparsity and quantization settings are deliberately excluded
+    because the baseline (non-quantized, non-sparse) performance is
+    independent of those attributes.
+    """
+    return (
+        config.model_type,
+        config.m,
+        config.k,
+        config.n,
+        config.high_precision_dtype,
+        config.device,
+        config.torch_compile_mode,
+    )
+

def run(config: BenchmarkConfig) -> BenchmarkResult:
-    """Run inference benchmarks"""
+    """
+    Run inference benchmarks.
+
+    The function first checks whether a baseline for the given configuration
+    already exists in the internal cache. If not, it measures the eager and
+    compiled baseline inference times and stores them; otherwise it reuses the
+    cached values to calculate the speedup metrics.
+
+    Args:
+        config (BenchmarkConfig): Benchmark configuration.
+
+    Returns:
+        BenchmarkResult: Result of the benchmark.
+    """
    try:
        clean_caches()  # Clean caches

        # Create output directory if it doesn't exist
        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

+        # Prepare result container
+        result = BenchmarkResult(config=config)
+
+        # Create model and input data
        base_model, input_data = create_model_and_input_data(
            config.model_type,
            config.m,
@@ -51,28 +109,47 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
            high_precision_dtype=config.high_precision_dtype,
            device=config.device,
        )
-        # Copy base model for quantizing
-        m_copy = deepcopy(base_model)

-        # Run benchmarks
-        result = BenchmarkResult(config=config)
+        # Generate a cache key for the current configuration
+        cache_key = _make_cache_key(config)

-        # Store result in model for memory profiling
-        base_model._benchmark_result = result
-
-        # Run baseline benchmarking
-        base_model = base_model.eval().to(config.device)
-        if config.use_torch_compile:
-            print("Compiling baseline model....")
-            base_model = torch.compile(
-                base_model, mode=config.torch_compile_mode, fullgraph=True
+        # Check if the baseline for this configuration has already been computed
+        if cache_key not in _BASELINE_CACHE:
+            # Copy the model, switch it to eval mode, and move it to the target device
+            m_copy = deepcopy(base_model)
+            m_copy = m_copy.eval().to(config.device)
+            print("Benchmarking eager baseline inference.....")
+            eager_baseline_time = model_inference_time_in_ms(
+                model=m_copy, input_data=input_data
            )
-        # Benchmark time to run an inference call for baseline model
-        print("Benchmarking baseline inference.....")
-        result.baseline_inference_time_in_ms = model_inference_time_in_ms(
-            model=base_model, input_data=input_data
-        )

+            print("Benchmarking compiled baseline inference.....")
+            m_copy = torch.compile(
+                m_copy, mode=config.torch_compile_mode, fullgraph=True
+            )
+            compile_baseline_time = model_inference_time_in_ms(
+                model=m_copy, input_data=input_data
+            )
+
+            # Cache the eager and compiled baseline times
+            _BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)
+
+            result.baseline_model_eager_inference_time_in_ms = eager_baseline_time
+            result.baseline_model_compiled_inference_time_in_ms = compile_baseline_time
+        else:
+            # Retrieve cached values
+            cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
+            result.baseline_model_eager_inference_time_in_ms = cached_eager_time
+            result.baseline_model_compiled_inference_time_in_ms = cached_compile_time
+
+        # At this point ``base_model`` is an uncompiled model ready for quantization,
+        # ``input_data`` is the corresponding input tensor, and the baseline times
+        # are stored on ``result``.
+
+        # Copy base model for quantizing/sparsifying
+        m_copy = deepcopy(base_model)
+
+        # Determine quantization/sparsity configuration
        ao_base_config = string_to_config(
            config.quantization,
            config.sparsity,
@@ -101,24 +178,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
        m_copy = m_copy.eval().to(config.device)
        quantize_(m_copy, ao_base_config)

-        if config.use_torch_compile:
-            print("Compiling quantized model....")
-            m_copy = torch.compile(
-                m_copy, mode=config.torch_compile_mode, fullgraph=True
-            )
-
        # Store result in model for memory profiling
        m_copy._benchmark_result = result

-        # Benchmark time to run an inference call for quantized model
-        print("Benchmarking quantized model.....")
-        result.model_inference_time_in_ms = model_inference_time_in_ms(
+        # Measure inference time for the eager quantized model
+        print("Benchmarking eager quantized model.....")
+        result.quantized_model_eager_inference_time_in_ms = model_inference_time_in_ms(
            model=m_copy, input_data=input_data
        )

-        # Calculate speedup w.r.t. baseline
-        result.speedup = round(
-            result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
+        # Measure inference time for the compiled quantized model
+        print("Benchmarking compiled quantized model.....")
+        m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
+        result.quantized_model_compiled_inference_time_in_ms = (
+            model_inference_time_in_ms(model=m_copy, input_data=input_data)
+        )
+
+        # Compute eager speedup relative to the eager baseline
+        result.eager_speedup_on_baseline = round(
+            result.baseline_model_eager_inference_time_in_ms
+            / result.quantized_model_eager_inference_time_in_ms,
+            ndigits=2,
+        )
+        # Compute compiled speedup relative to the compiled baseline
+        result.compile_speedup_on_baseline = round(
+            result.baseline_model_compiled_inference_time_in_ms
+            / result.quantized_model_compiled_inference_time_in_ms,
+            ndigits=2,
+        )
+        # Compute speedup of the compiled quantized model over the eager quantized model
+        result.compile_speedup_on_eager = round(
+            result.quantized_model_eager_inference_time_in_ms
+            / result.quantized_model_compiled_inference_time_in_ms,
+            ndigits=2,
        )

        # Run profiler if enabled
@@ -165,9 +257,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
                    result.memory_profile_path
                )
            except ValueError as e:
-                if "not enough values to unpack" in e:
+                if "not enough values to unpack" in str(e):
                    print(
-                        "Failed due to existing bugs, re- run the code to generate memory profile. Please raise an issue if it persists."
+                        "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
                    )
            except Exception as e:
                print(f"Error running memory profiler: {e}")
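A small worked example of the three speedup fields computed above, using the baseline times from the example cache entry (95.00 ms eager, 56.00 ms compiled) and made-up quantized times:

```python
# Baseline times from the example cache entry; quantized times are hypothetical.
eager_baseline_ms = 95.00
compiled_baseline_ms = 56.00
eager_quantized_ms = 47.50
compiled_quantized_ms = 28.00

eager_speedup_on_baseline = round(eager_baseline_ms / eager_quantized_ms, ndigits=2)          # 2.0
compile_speedup_on_baseline = round(compiled_baseline_ms / compiled_quantized_ms, ndigits=2)  # 2.0
compile_speedup_on_eager = round(eager_quantized_ms / compiled_quantized_ms, ndigits=2)       # 1.7

print(eager_speedup_on_baseline, compile_speedup_on_baseline, compile_speedup_on_eager)
```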