|
18 | 18 | import torch |
19 | 19 |
|
20 | 20 | from executorch.backends.arm.test.common import arm_test_options, is_option_enabled |
| 21 | +import tosa_reference_model |
21 | 22 |
|
22 | 23 | from torch.export import ExportedProgram |
23 | 24 | from torch.fx.node import Node |
| 25 | +from tosa import TosaGraph |
24 | 26 |
|
25 | 27 | logger = logging.getLogger(__name__) |
26 | | -logger.setLevel(logging.WARNING) |
| 28 | +logger.setLevel(logging.CRITICAL) |
27 | 29 |
|
28 | 30 |
|
29 | 31 | class QuantizationParams: |
@@ -169,7 +171,7 @@ def __init__( |
169 | 171 | ): |
170 | 172 | self.intermediate_path = intermediate_path |
171 | 173 | self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" |
172 | | - assert os.path.exists( |
| 174 | + assert self.intermediate_path is None or os.path.exists( |
173 | 175 | self.intermediate_path |
174 | 176 | ), f"TOSA artifact path doesn't exist! Path: {self.intermediate_path}" |
175 | 177 |
|
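The relaxed guard above now tolerates a missing intermediate path. A minimal standalone sketch of the check, reproduced outside the class since the enclosing runner type is not shown in this hunk:

```python
import os

# With the updated guard, a None intermediate_path skips the existence check;
# only a non-None path that does not exist trips the assertion.
intermediate_path = None
assert intermediate_path is None or os.path.exists(
    intermediate_path
), f"TOSA artifact path doesn't exist! Path: {intermediate_path}"
```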
@@ -334,7 +336,46 @@ def run_corstone( |
334 | 336 | tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) |
335 | 337 | output_shape = self.output_node.args[0][0].meta["val"].shape |
336 | 338 | tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) |
337 | | - return [tosa_ref_output] |
| 339 | + return tosa_ref_output |
| 340 | + |
| 341 | + def run_tosa_graph( |
| 342 | + self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor] |
| 343 | + ) -> tuple[torch.Tensor, ...]: |
| 344 | + """Runs the TOSA reference model with inputs and returns the result.""" |
| 345 | + data_np = [ |
| 346 | + prep_data_for_save( |
| 347 | + tensor, self.is_quantized, self.input_names[i], self.qp_input[i] |
| 348 | + ) |
| 349 | + for i, tensor in enumerate(inputs) |
| 350 | + ] |
| 351 | + # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. |
| 352 | + tosa_profile = 0 if self.is_quantized else 1 |
| 353 | + debug_mode = "ALL" if logger.level <= logging.DEBUG else None |
| 354 | + outputs, status = tosa_reference_model.run( |
| 355 | + graph, |
| 356 | + data_np, |
| 357 | + verbosity=_tosa_refmodel_loglevel(logger.level), |
| 358 | + tosa_profile=tosa_profile, |
| 359 | + initialize_variable_tensor_from_numpy=1, # True |
| 360 | + debug_mode=debug_mode, |
| 361 | + ) |
| 362 | + |
| 363 | + assert ( |
| 364 | + status == tosa_reference_model.GraphStatus.TOSA_VALID |
| 365 | + ), "Invalid TOSA graph given to the reference model." |
| 366 | + |
| 367 | + outputs_torch = [] |
| 368 | + for output in outputs: |
| 369 | + output = torch.from_numpy(output) |
| 370 | + if self.is_quantized: |
| 371 | + # Need to dequant back to FP32 for comparison with torch output |
| 372 | + quant_param = self.qp_output |
| 373 | + assert ( |
| 374 | + quant_param is not None |
| 375 | + ), "No quantization parameters found; check the output parameters." |
| 376 | + output = (output.to(torch.float32) - quant_param.zp) * quant_param.scale |
| 377 | + outputs_torch.append(output) |
| 378 | + return tuple(outputs_torch) |
338 | 379 |
|
339 | 380 | def run_tosa_ref_model( |
340 | 381 | self, |
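A hedged usage sketch for the new `run_tosa_graph` method. Only its signature and tuple return value come from the hunk above; how the runner instance and the serialized `TosaGraph` are produced is assumed and left to the caller:

```python
from typing import Sequence

import torch


def check_against_eager(
    runner, tosa_graph, inputs: Sequence[torch.Tensor], expected: torch.Tensor
) -> None:
    # run_tosa_graph returns a tuple of torch.Tensors; quantized outputs are
    # already dequantized back to FP32 inside the method.
    outputs = runner.run_tosa_graph(tosa_graph, list(inputs))
    torch.testing.assert_close(outputs[0], expected, rtol=1e-3, atol=1e-3)
```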
@@ -419,21 +460,13 @@ def run_tosa_ref_model( |
419 | 460 | assert ( |
420 | 461 | shutil.which(self.tosa_ref_model_path) is not None |
421 | 462 | ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}" |
422 | | - loglevel_map = { |
423 | | - logging.INFO: "INFO", |
424 | | - logging.CRITICAL: "LOW", |
425 | | - logging.ERROR: "LOW", |
426 | | - logging.WARNING: "MED", |
427 | | - logging.DEBUG: "HIGH", |
428 | | - logging.NOTSET: "MED", |
429 | | - } |
430 | | - clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0) |
| 463 | + |
431 | 464 | cmd_ref_model = [ |
432 | 465 | self.tosa_ref_model_path, |
433 | 466 | "--test_desc", |
434 | 467 | desc_file_path, |
435 | 468 | "-l", |
436 | | - loglevel_map[clamped_logging_level], |
| 469 | + _tosa_refmodel_loglevel(logger.level), |
437 | 470 | ] |
438 | 471 | _run_cmd(cmd_ref_model) |
439 | 472 |
|
@@ -469,7 +502,10 @@ def run_tosa_ref_model( |
469 | 502 |
|
470 | 503 |
|
471 | 504 | def prep_data_for_save( |
472 | | - data, is_quantized: bool, input_name: str, quant_param: QuantizationParams |
| 505 | + data: torch.Tensor, |
| 506 | + is_quantized: bool, |
| 507 | + input_name: str, |
| 508 | + quant_param: QuantizationParams, |
473 | 509 | ): |
474 | 510 | data_np = np.array(data.detach(), order="C").astype( |
475 | 511 | f"{data.dtype}".replace("torch.", "") |
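For reference, the dtype handling visible in the two lines above can be reproduced in isolation: the torch dtype string doubles as the numpy dtype string once the `torch.` prefix is stripped. A small self-contained sketch, not part of the patch:

```python
import numpy as np
import torch

data = torch.ones(2, 2, dtype=torch.float32)
np_dtype = f"{data.dtype}".replace("torch.", "")  # "float32"
data_np = np.array(data.detach(), order="C").astype(np_dtype)
assert data_np.dtype == np.float32
```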
@@ -578,7 +614,6 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: |
578 | 614 | assert os.path.exists( |
579 | 615 | tosa_schema_file |
580 | 616 | ), f"tosa_schema_file: {tosa_schema_file} does not exist" |
581 | | - |
582 | 617 | assert shutil.which("flatc") is not None |
583 | 618 | cmd_flatc = [ |
584 | 619 | "flatc", |
@@ -613,3 +648,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: |
613 | 648 | pass |
614 | 649 |
|
615 | 650 | return json_out |
| 651 | + |
| 652 | + |
| 653 | +def _tosa_refmodel_loglevel(loglevel: int) -> str: |
| 654 | + """Converts a Python logging level to the corresponding |
| 655 | + tosa_reference_model verbosity string. |
| 656 | + """ |
| 657 | + loglevel_map = { |
| 658 | + logging.INFO: "INFO", |
| 659 | + logging.CRITICAL: "LOW", |
| 660 | + logging.ERROR: "LOW", |
| 661 | + logging.WARNING: "MED", |
| 662 | + logging.DEBUG: "HIGH", |
| 663 | + logging.NOTSET: "MED", |
| 664 | + } |
| 665 | + clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) |
| 666 | + return loglevel_map[clamped_logging_level] |
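A quick sanity sketch of the helper's clamping behaviour, assuming `_tosa_refmodel_loglevel` is in scope: the level is floored to the nearest multiple of 10 and clamped to [0, 50] before the lookup, so non-standard values still resolve to a known verbosity string.

```python
import logging

assert _tosa_refmodel_loglevel(logging.DEBUG) == "HIGH"   # 10 -> HIGH
assert _tosa_refmodel_loglevel(logging.WARNING) == "MED"  # 30 -> MED
assert _tosa_refmodel_loglevel(17) == "HIGH"              # floored to 10
assert _tosa_refmodel_loglevel(55) == "LOW"               # clamped to 50 (CRITICAL)
```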