@@ -13,9 +13,10 @@
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
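In `measure_compute_rps`, the second hunk seeds the cache with placeholder past keys and routes both the warm-up call and the timed loop through a single `step` helper, so the inference path exercises the KV cache on every step; a standalone sketch of the resulting benchmark pattern follows the diff.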
@@ -201,18 +202,25 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        # Skip the 1st step to exclude the initialization time
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
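For reference, the timing logic above follows the usual warm-up-then-measure benchmark pattern: one untimed call absorbs one-off initialization costs, and only the next `n_steps` calls count toward the throughput figure. Below is a minimal, self-contained sketch of that pattern; `benchmark_rps` and `ToyBlock` are hypothetical names for illustration, assuming an HF-style block that accepts `use_cache`/`layer_past` and returns a `(hidden_states, cache)` tuple (the real `measure_compute_rps` additionally handles dtype conversion, quantization, and tensor parallelism):

```python
import time

import torch


def benchmark_rps(block, n_tokens, n_steps, hidden_size, device, inference=True):
    """Warm-up-then-measure throughput benchmark (hypothetical helper)."""
    dummy_input = torch.randn(1, n_tokens, hidden_size, device=device)
    cache = None  # the real code seeds this with DUMMY_KEY_PAST placeholders instead

    def step(cache_):
        # In inference mode, feed the KV cache back in so the cached-attention path is measured
        outputs = block(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
        return outputs[1] if inference else None

    cache = step(cache)  # warm-up step: excluded from the timing
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # flush queued kernels before starting the clock

    start_time = time.perf_counter()
    for _ in range(n_steps):
        cache = step(cache)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start_time
    return n_steps * n_tokens / elapsed  # tokens processed per second


# Toy block mimicking the (hidden_states, cache) interface, just to make the sketch runnable:
class ToyBlock(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x, use_cache=False, layer_past=None):
        out = self.proj(x)
        return (out, layer_past) if use_cache else (out,)


print(benchmark_rps(ToyBlock(64), n_tokens=16, n_steps=10, hidden_size=64, device=torch.device("cpu")))
```

The detail worth mirroring from the diff is that the same `step` closure serves both the warm-up and the timed loop, so the two code paths are identical except for what the clock covers.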