 import numpy as np
 import torch

-import tosa_reference_model
-
 from torch.export import ExportedProgram
 from torch.fx.node import Node
-from tosa import TosaGraph

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.CRITICAL)
+logger.setLevel(logging.WARNING)


 class QuantizationParams:
@@ -170,7 +167,7 @@ def __init__(
     ):
         self.intermediate_path = intermediate_path
         self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
-        assert self.intermediate_path is None or os.path.exists(
+        assert os.path.exists(
             self.intermediate_path
         ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"

@@ -326,46 +323,7 @@ def run_corstone(
         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
         output_shape = self.output_node.args[0][0].meta["val"].shape
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
-        return tosa_ref_output
-
-    def run_tosa_graph(
-        self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
-    ) -> torch.Tensor:
-        """Runs the TOSA reference model with inputs and returns the result."""
-        data_np = [
-            prep_data_for_save(
-                input, self.is_quantized, self.input_names[i], self.qp_input[i]
-            )
-            for i, input in enumerate(inputs)
-        ]
-        # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
-        tosa_profile = 0 if self.is_quantized else 1
-        debug_mode = "ALL" if logger.level <= logging.DEBUG else None
-        outputs, status = tosa_reference_model.run(
-            graph,
-            data_np,
-            verbosity=_tosa_refmodel_loglevel(logger.level),
-            tosa_profile=tosa_profile,
-            initialize_variable_tensor_from_numpy=1,  # True
-            debug_mode=debug_mode,
-        )
-
-        assert (
-            status == tosa_reference_model.GraphStatus.TOSA_VALID
-        ), "Non-valid TOSA given to reference model."
-
-        outputs_torch = []
-        for output in outputs:
-            output = output.astype(np.float32)
-            if self.is_quantized:
-                # Need to dequant back to FP32 for comparison with torch output
-                quant_param = self.qp_output
-                assert (
-                    quant_param is not None
-                ), "There are no quantization parameters, check output parameters"
-                output = (output - quant_param.zp) * quant_param.scale
-            outputs_torch.append(torch.from_numpy(output))
-        return tuple(outputs_torch)
+        return [tosa_ref_output]

     def run_tosa_ref_model(
         self,
@@ -450,13 +408,21 @@ def run_tosa_ref_model(
         assert (
             shutil.which(self.tosa_ref_model_path) is not None
         ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
-
+        loglevel_map = {
+            logging.INFO: "INFO",
+            logging.CRITICAL: "LOW",
+            logging.ERROR: "LOW",
+            logging.WARNING: "MED",
+            logging.DEBUG: "HIGH",
+            logging.NOTSET: "MED",
+        }
+        clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
         cmd_ref_model = [
             self.tosa_ref_model_path,
             "--test_desc",
             desc_file_path,
             "-l",
-            _tosa_refmodel_loglevel(logger.level),
+            loglevel_map[clamped_logging_level],
         ]
         _run_cmd(cmd_ref_model)

@@ -492,10 +458,7 @@ def run_tosa_ref_model(


 def prep_data_for_save(
-    data: torch.Tensor,
-    is_quantized: bool,
-    input_name: str,
-    quant_param: QuantizationParams,
+    data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
 ):
     data_np = np.array(data.detach(), order="C").astype(
         f"{data.dtype}".replace("torch.", "")
@@ -639,19 +602,3 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
             pass

     return json_out
-
-
-def _tosa_refmodel_loglevel(loglevel: int) -> str:
-    """Converts a logging loglevel to tosa_reference_model logginglevel,
-    returned as string.
-    """
-    loglevel_map = {
-        logging.INFO: "INFO",
-        logging.CRITICAL: "LOW",
-        logging.ERROR: "LOW",
-        logging.WARNING: "MED",
-        logging.DEBUG: "HIGH",
-        logging.NOTSET: "MED",
-    }
-    clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
-    return loglevel_map[clamped_logging_level]
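For reference, the log-level handling that this change inlines into run_tosa_ref_model (previously the deleted _tosa_refmodel_loglevel helper) rounds the Python logging level down to the nearest multiple of 10, clamps it to the 0..50 range, and looks the result up in loglevel_map. Below is a minimal standalone sketch of that behaviour; the wrapper name resolve_refmodel_loglevel is hypothetical and used only for illustration, the mapping and clamp expression are taken from the diff.

# Minimal sketch (not part of the patch) of the inlined log-level mapping.
# "resolve_refmodel_loglevel" is a hypothetical name for this example only.
import logging

def resolve_refmodel_loglevel(loglevel: int) -> str:
    loglevel_map = {
        logging.INFO: "INFO",       # 20
        logging.CRITICAL: "LOW",    # 50
        logging.ERROR: "LOW",       # 40
        logging.WARNING: "MED",     # 30
        logging.DEBUG: "HIGH",      # 10
        logging.NOTSET: "MED",      # 0
    }
    # Round down to the nearest multiple of 10 and clamp to [0, 50] so that
    # non-standard levels (e.g. 25 or 60) still hit a key in the map.
    clamped = max(min(loglevel // 10 * 10, 50), 0)
    return loglevel_map[clamped]

assert resolve_refmodel_loglevel(logging.DEBUG) == "HIGH"
assert resolve_refmodel_loglevel(25) == "INFO"  # 25 rounds down to 20
assert resolve_refmodel_loglevel(60) == "LOW"   # clamped to 50 (CRITICAL)

Since the module logger now defaults to logging.WARNING (level 30), the reference model is invoked with verbosity "MED" unless the caller raises or lowers the logger level.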