|
16 | 16 |
|
17 | 17 | import numpy as np |
18 | 18 | import torch |
| 19 | +import tosa_reference_model |
19 | 20 |
|
20 | 21 | from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled |
21 | 22 |
|
22 | 23 | from torch.export import ExportedProgram |
23 | 24 | from torch.fx.node import Node |
| 25 | +from tosa import TosaGraph |
24 | 26 |
|
25 | 27 | logger = logging.getLogger(__name__) |
26 | | -logger.setLevel(logging.WARNING) |
| 28 | +logger.setLevel(logging.CRITICAL) |
27 | 29 |
|
28 | 30 |
|
29 | 31 | class QuantizationParams: |
@@ -169,7 +171,7 @@ def __init__( |
169 | 171 | ): |
170 | 172 | self.intermediate_path = intermediate_path |
171 | 173 | self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" |
172 | | - assert os.path.exists( |
| 174 | + assert self.intermediate_path is None or os.path.exists( |
173 | 175 | self.intermediate_path |
174 | 176 | ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" |
175 | 177 |
|
@@ -332,7 +334,46 @@ def run_corstone( |
332 | 334 | tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) |
333 | 335 | output_shape = self.output_node.args[0][0].meta["val"].shape |
334 | 336 | tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) |
335 | | - return [tosa_ref_output] |
| 337 | + return tosa_ref_output |
| 338 | + |
def run_tosa_graph(
    self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
) -> tuple[torch.Tensor, ...]:
    """Run the TOSA reference model on ``graph`` and return its outputs.

    Args:
        graph: TOSA graph passed unchanged to ``tosa_reference_model.run``.
        inputs: One entry per graph input, indexed in step with
            ``self.input_names`` and ``self.qp_input``. NumPy arrays are
            converted to tensors first, because ``prep_data_for_save``
            calls ``.detach()`` on the data it receives.

    Returns:
        A tuple holding one tensor per reference-model output. When
        ``self.is_quantized`` is set, every output is converted to float32
        and dequantized with ``self.qp_output``.

    Raises:
        AssertionError: If the reference model reports a status other than
            ``TOSA_VALID``, or if the model is quantized and
            ``self.qp_output`` is ``None``.
    """
    data_np = [
        prep_data_for_save(
            # as_tensor leaves tensors untouched and wraps NumPy arrays.
            torch.as_tensor(data),
            self.is_quantized,
            self.input_names[i],
            self.qp_input[i],
        )
        for i, data in enumerate(inputs)
    ]
    # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
    tosa_profile = 0 if self.is_quantized else 1
    debug_mode = "ALL" if logger.level <= logging.DEBUG else None
    outputs, status = tosa_reference_model.run(
        graph,
        data_np,
        verbosity=_tosa_refmodel_loglevel(logger.level),
        tosa_profile=tosa_profile,
        initialize_variable_tensor_from_numpy=1,  # True
        debug_mode=debug_mode,
    )

    assert (
        status == tosa_reference_model.GraphStatus.TOSA_VALID
    ), "Non-valid TOSA given to reference model."

    if not self.is_quantized:
        return tuple(torch.from_numpy(output) for output in outputs)

    # The same output quantization parameters apply to every output, so
    # validate them once instead of once per output.
    quant_param = self.qp_output
    assert (
        quant_param is not None
    ), "There are no quantization parameters, check output parameters"

    # Need to dequant back to FP32 for comparison with torch output.
    return tuple(
        (torch.from_numpy(output).to(torch.float32) - quant_param.zp)
        * quant_param.scale
        for output in outputs
    )
336 | 377 |
|
337 | 378 | def run_tosa_ref_model( |
338 | 379 | self, |
@@ -417,21 +458,13 @@ def run_tosa_ref_model( |
417 | 458 | assert ( |
418 | 459 | shutil.which(self.tosa_ref_model_path) is not None |
419 | 460 | ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}" |
420 | | - loglevel_map = { |
421 | | - logging.INFO: "INFO", |
422 | | - logging.CRITICAL: "LOW", |
423 | | - logging.ERROR: "LOW", |
424 | | - logging.WARNING: "MED", |
425 | | - logging.DEBUG: "HIGH", |
426 | | - logging.NOTSET: "MED", |
427 | | - } |
428 | | - clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0) |
| 461 | + |
429 | 462 | cmd_ref_model = [ |
430 | 463 | self.tosa_ref_model_path, |
431 | 464 | "--test_desc", |
432 | 465 | desc_file_path, |
433 | 466 | "-l", |
434 | | - loglevel_map[clamped_logging_level], |
| 467 | + _tosa_refmodel_loglevel(logger.level), |
435 | 468 | ] |
436 | 469 | _run_cmd(cmd_ref_model) |
437 | 470 |
|
@@ -467,7 +500,10 @@ def run_tosa_ref_model( |
467 | 500 |
|
468 | 501 |
|
469 | 502 | def prep_data_for_save( |
470 | | - data, is_quantized: bool, input_name: str, quant_param: QuantizationParams |
| 503 | + data: torch.Tensor, |
| 504 | + is_quantized: bool, |
| 505 | + input_name: str, |
| 506 | + quant_param: QuantizationParams, |
471 | 507 | ): |
472 | 508 | data_np = np.array(data.detach(), order="C").astype( |
473 | 509 | f"{data.dtype}".replace("torch.", "") |
@@ -576,7 +612,6 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: |
576 | 612 | assert os.path.exists( |
577 | 613 | tosa_schema_file |
578 | 614 | ), f"tosa_schema_file: {tosa_schema_file} does not exist" |
579 | | - |
580 | 615 | assert shutil.which("flatc") is not None |
581 | 616 | cmd_flatc = [ |
582 | 617 | "flatc", |
@@ -611,3 +646,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: |
611 | 646 | pass |
612 | 647 |
|
613 | 648 | return json_out |
| 649 | + |
| 650 | + |
def _tosa_refmodel_loglevel(loglevel: int) -> str:
    """Map a Python ``logging`` level to a tosa_reference_model verbosity name.

    The level is rounded down to a multiple of ten and confined to the range
    0..50 (``logging.NOTSET`` .. ``logging.CRITICAL``) before the lookup, so
    custom and out-of-range levels still resolve to one of the entries below.
    """
    # Round down to a multiple of ten, then clamp into [NOTSET, CRITICAL].
    bucket = loglevel // 10 * 10
    if bucket > 50:
        bucket = 50
    if bucket < 0:
        bucket = 0

    verbosity_by_level = {
        logging.NOTSET: "MED",
        logging.DEBUG: "HIGH",
        logging.INFO: "INFO",
        logging.WARNING: "MED",
        logging.ERROR: "LOW",
        logging.CRITICAL: "LOW",
    }
    return verbosity_by_level[bucket]
0 commit comments