 import numpy as np
 import torch
 
+import tosa_reference_model
+
 from torch.export import ExportedProgram
 from torch.fx.node import Node
+from tosa import TosaGraph
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.CRITICAL)
 
 
 class QuantizationParams:
@@ -167,7 +170,7 @@ def __init__(
     ):
         self.intermediate_path = intermediate_path
         self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
-        assert os.path.exists(
+        assert self.intermediate_path is None or os.path.exists(
             self.intermediate_path
         ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
 
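Note: the guard above makes the intermediate artifact directory optional, so construction no longer fails when no intermediate_path is supplied. A minimal standalone sketch of the same idiom (check_artifact_path is a hypothetical helper, not part of this patch):

import os

def check_artifact_path(intermediate_path: str | None) -> None:
    # Skip the existence check entirely when no artifact directory was requested.
    assert intermediate_path is None or os.path.exists(
        intermediate_path
    ), f"TOSA artifact path don't exist! Path: {intermediate_path}"

check_artifact_path(None)  # accepted: no intermediate artifacts requested
check_artifact_path(".")   # accepted: the current directory exists
# check_artifact_path("/missing/dir") would raise an AssertionError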
@@ -323,7 +326,46 @@ def run_corstone(
         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
         output_shape = self.output_node.args[0][0].meta["val"].shape
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
-        return [tosa_ref_output]
+        return tosa_ref_output
+
+    def run_tosa_graph(
+        self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
+    ) -> torch.Tensor:
+        """Runs the TOSA reference model with inputs and returns the result."""
+        data_np = [
+            prep_data_for_save(
+                input, self.is_quantized, self.input_names[i], self.qp_input[i]
+            )
+            for i, input in enumerate(inputs)
+        ]
+        # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
+        tosa_profile = 0 if self.is_quantized else 1
+        debug_mode = "ALL" if logger.level <= logging.DEBUG else None
+        outputs, status = tosa_reference_model.run(
+            graph,
+            data_np,
+            verbosity=_tosa_refmodel_loglevel(logger.level),
+            tosa_profile=tosa_profile,
+            initialize_variable_tensor_from_numpy=1,  # True
+            debug_mode=debug_mode,
+        )
+
+        assert (
+            status == tosa_reference_model.GraphStatus.TOSA_VALID
+        ), "Non-valid TOSA given to reference model."
+
+        outputs_torch = []
+        for output in outputs:
+            output = output.astype(np.float32)
+            if self.is_quantized:
+                # Need to dequant back to FP32 for comparison with torch output
+                quant_param = self.qp_output
+                assert (
+                    quant_param is not None
+                ), "There are no quantization parameters, check output parameters"
+                output = (output - quant_param.zp) * quant_param.scale
+            outputs_torch.append(torch.from_numpy(output))
+        return tuple(outputs_torch)
 
     def run_tosa_ref_model(
         self,
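Note: run_tosa_graph drives the tosa_reference_model Python bindings directly and, for quantized graphs, dequantizes the reference output with the recorded output zero-point and scale before comparison. A minimal sketch of that dequantization step in isolation, assuming the same affine scheme as above (dequantize_output is an illustrative name, not part of this patch):

import numpy as np
import torch

def dequantize_output(output: np.ndarray, scale: float, zp: int) -> torch.Tensor:
    # real_value = (quantized_value - zero_point) * scale, as in run_tosa_graph
    return torch.from_numpy((output.astype(np.float32) - zp) * scale)

# An int8 reference-model output with scale=0.05 and zero_point=-3:
q = np.array([-3, 17, 37], dtype=np.int8)
print(dequantize_output(q, scale=0.05, zp=-3))  # tensor([0., 1., 2.])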
@@ -408,21 +450,13 @@ def run_tosa_ref_model(
         assert (
             shutil.which(self.tosa_ref_model_path) is not None
         ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
-        loglevel_map = {
-            logging.INFO: "INFO",
-            logging.CRITICAL: "LOW",
-            logging.ERROR: "LOW",
-            logging.WARNING: "MED",
-            logging.DEBUG: "HIGH",
-            logging.NOTSET: "MED",
-        }
-        clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
+
         cmd_ref_model = [
             self.tosa_ref_model_path,
             "--test_desc",
             desc_file_path,
             "-l",
-            loglevel_map[clamped_logging_level],
+            _tosa_refmodel_loglevel(logger.level),
         ]
         _run_cmd(cmd_ref_model)
 
@@ -455,7 +489,10 @@ def run_tosa_ref_model(
 
 
 def prep_data_for_save(
-    data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
+    data: torch.Tensor,
+    is_quantized: bool,
+    input_name: str,
+    quant_param: QuantizationParams,
 ):
     data_np = np.array(data.detach(), order="C").astype(np.float32)
 
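Note: this hunk only adds explicit type annotations; the visible body line is unchanged. For reference, the conversion on that first line behaves as follows (standalone sketch, not the full function):

import numpy as np
import torch

data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# detach() drops autograd history; order="C" forces a contiguous row-major
# buffer before the data is handed on to the reference model.
data_np = np.array(data.detach(), order="C").astype(np.float32)
print(data_np.flags["C_CONTIGUOUS"], data_np.dtype)  # True float32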
@@ -597,3 +634,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
         pass
 
     return json_out
+
+
+def _tosa_refmodel_loglevel(loglevel: int) -> str:
+    """Converts a logging loglevel to tosa_reference_model logginglevel,
+    returned as string.
+    """
+    loglevel_map = {
+        logging.INFO: "INFO",
+        logging.CRITICAL: "LOW",
+        logging.ERROR: "LOW",
+        logging.WARNING: "MED",
+        logging.DEBUG: "HIGH",
+        logging.NOTSET: "MED",
+    }
+    clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
+    return loglevel_map[clamped_logging_level]
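Note: _tosa_refmodel_loglevel centralizes the verbosity mapping that was previously inlined in run_tosa_ref_model, so the CLI -l flag and the new run_tosa_graph call report at the same level. The Python log level is clamped to the 0-50 range in steps of ten and then looked up in the table above; a small usage sketch, assuming the helper is in scope:

import logging

# DEBUG(10) -> "HIGH", INFO(20) -> "INFO", WARNING(30) -> "MED",
# ERROR(40) and CRITICAL(50) -> "LOW"; values below 10 clamp to NOTSET(0) -> "MED".
for level in (5, logging.DEBUG, logging.INFO, 35, logging.CRITICAL, 60):
    print(level, _tosa_refmodel_loglevel(level))
# 5 MED / 10 HIGH / 20 INFO / 35 MED / 50 LOW / 60 LOW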