|
21 | 21 | through type checking and cleanup of redundant operations. |
22 | 22 | """ |
23 | 23 |
|
24 | | -from collections import namedtuple |
| 24 | +from collections import defaultdict, namedtuple |
25 | 25 | from copy import deepcopy |
| 26 | +from dataclasses import dataclass, field |
26 | 27 |
|
27 | 28 | import ml_dtypes |
28 | 29 | import numpy as np |
|
39 | 40 |
|
40 | 41 | PrecisionTypes = namedtuple("PrecisionTypes", ["onnx_type", "numpy_type", "str_short", "str_full"]) |
41 | 42 |
|
| 43 | + |
| 44 | +@dataclass |
| 45 | +class InputIndexTracker: |
| 46 | +    """Tracks a consumer node and the index of the node input that references a given initializer.""" |
| 47 | + |
| 48 | + node: onnx.NodeProto |
| 49 | + node_index: int |
| 50 | + |
| 51 | + |
| 52 | +@dataclass |
| 53 | +class InitializerConsumerTracker: |
| 54 | +    """Tracks the low and high precision consumer nodes of an initializer.""" |
| 55 | + |
| 56 | + low_precision_nodes: list[InputIndexTracker] = field(default_factory=list) |
| 57 | + high_precision_nodes: list[InputIndexTracker] = field(default_factory=list) |
| 58 | + |
| 59 | + |
42 | 60 | PRECISION_MAP = { |
43 | 61 | "fp32": PrecisionTypes(TensorProto.FLOAT, np.float32, "fp32", "float32"), |
44 | 62 | "fp16": PrecisionTypes(TensorProto.FLOAT16, np.float16, "fp16", "float16"), |
@@ -472,133 +490,247 @@ def _get_tensors_to_cast( |
472 | 490 | def _convert_initializers( |
473 | 491 | self, low_precision_nodes: list[str], high_precision_nodes: list[str] |
474 | 492 | ) -> onnx.ModelProto: |
475 | | - def convert_initializer( |
476 | | - init: onnx.TensorProto, |
477 | | - node: onnx.NodeProto, |
478 | | - from_type: PrecisionTypes, |
479 | | - to_type: PrecisionTypes, |
480 | | - ): |
481 | | - if init.data_type != from_type.onnx_type: |
| 493 | + """Convert model initializers to appropriate precision based on their consumer nodes. |
| 494 | +
|
| 495 | + This method analyzes how each initializer is used by different precision nodes and converts |
| 496 | + or duplicates initializers as needed to ensure type compatibility: |
| 497 | +
|
| 498 | + 1. Maps each initializer to the high/low precision nodes that consume it |
| 499 | + 2. For each initializer, applies one of these strategies: |
| 500 | + - If only used by low precision nodes: convert to low precision |
| 501 | + - If only used by high precision nodes: convert to high precision |
| 502 | + - If used by both precision types: duplicate the initializer, creating separate |
| 503 | + copies for each precision type and updating node references accordingly |
| 504 | + 3. Skips conversion for non-float initializers or those already at correct precision |
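| | +
| | + For example, a float32 initializer shared by float16 and float32 consumers keeps its |
| | + original float32 copy, while a duplicate suffixed with the target precision (e.g. "_fp16") |
| | + is created and the float16 consumers are re-pointed to it. |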
| 505 | +
|
| 506 | + The method handles special cases like bfloat16 conversion and provides warnings when |
| 507 | + values are clamped or replaced due to precision limits. |
| 508 | +
|
| 509 | + Args: |
| 510 | + low_precision_nodes: List of node names that should use low precision initializers. |
| 511 | + high_precision_nodes: List of node names that should use high precision initializers. |
| 512 | + """ |
| 513 | + # 1. Compute a mapping from initializers to the high and low precision nodes that use them. |
| 514 | + low_precision_nodes_set: set[str] = set(low_precision_nodes) |
| 515 | + high_precision_nodes_set: set[str] = set(high_precision_nodes) |
| 516 | + initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict( |
| 517 | + InitializerConsumerTracker |
| 518 | + ) |
| 519 | + for node in self.model.graph.node: |
| 520 | + # Compute the mapping from initializers to low precision nodes that use them. |
| 521 | + if node.name in low_precision_nodes_set: |
| 522 | + for idx, input_name in enumerate(node.input): |
| 523 | + if input_name in self.initializer_map: |
| 524 | + if self._should_skip_low_precision_input_conversion(node, input_name): |
| 525 | + # Handle low precision nodes that require certain high precision inputs. |
| 526 | + initializer_to_nodes[input_name].high_precision_nodes.append( |
| 527 | + InputIndexTracker(node=node, node_index=idx) |
| 528 | + ) |
| 529 | + else: |
| 530 | + initializer_to_nodes[input_name].low_precision_nodes.append( |
| 531 | + InputIndexTracker(node=node, node_index=idx) |
| 532 | + ) |
| 533 | + # Compute the mapping from initializers to high precision nodes that use them. |
| 534 | + elif node.name in high_precision_nodes_set: |
| 535 | + for idx, input_name in enumerate(node.input): |
| 536 | + if input_name in self.initializer_map: |
| 537 | + initializer_to_nodes[input_name].high_precision_nodes.append( |
| 538 | + InputIndexTracker(node=node, node_index=idx) |
| 539 | + ) |
| 540 | + |
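| | + # Float tensor element types that are eligible for conversion; use a set for fast lookups. |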
| 541 | + onnx_float_types = set(ONNX_TYPES) |
| 542 | + # 2. Convert initializers to appropriate precision based on their consumer nodes. |
| 543 | + for init_name, tracker in initializer_to_nodes.items(): |
| 544 | + # Get the initializer. |
| 545 | + init = self.initializer_map[init_name] |
| 546 | + # If not used, just skip. |
| 547 | + if len(tracker.low_precision_nodes) == 0 and len(tracker.high_precision_nodes) == 0: |
| 548 | + logger.debug(f"Initializer {init_name} is not used by any nodes, skipping") |
| 549 | + continue |
| 550 | + # If the initializer is not a float, then just skip. |
| 551 | + if init.data_type not in onnx_float_types: |
| 552 | + logger.debug(f"Initializer {init_name} is not a float, skipping") |
| 553 | + continue |
| 554 | + # If the initializer is only used by high precision nodes and is high precision, then just skip. |
| 555 | + if ( |
| 556 | + len(tracker.low_precision_nodes) == 0 |
| 557 | + and init.data_type == self.high_precision_type.onnx_type |
| 558 | + ): |
482 | 559 | logger.debug( |
483 | | - f"Initializer {init.name} has data type {init.data_type}, and size {len(init.raw_data)}," |
484 | | - "skipping conversion" |
| 560 | + f"Initializer {init_name} is already high precision and only used " |
| 561 | + "by high precision nodes, skipping" |
485 | 562 | ) |
486 | | - return False |
| 563 | + continue |
| 564 | + # If the initializer is only used by low precision nodes and is low precision, then just skip. |
| 565 | + if ( |
| 566 | + len(tracker.high_precision_nodes) == 0 |
| 567 | + and init.data_type == self.low_precision_type.onnx_type |
| 568 | + ): |
| 569 | + logger.debug( |
| 570 | + f"Initializer {init_name} is already low precision and only used " |
| 571 | + "by low precision nodes, skipping" |
| 572 | + ) |
| 573 | + continue |
| 574 | + |
| 575 | + # If the initializer is used by only one precision type, then convert it to the other precision type. |
| 576 | + if len(tracker.high_precision_nodes) == 0 or len(tracker.low_precision_nodes) == 0: |
| 577 | + if len(tracker.low_precision_nodes) > 0: |
| 578 | + logger.debug( |
| 579 | + f"Convert initializer {init_name} to " |
| 580 | + f"{self.low_precision_type.str_short}, only used by low precision nodes" |
| 581 | + ) |
| 582 | + from_type = self.high_precision_type |
| 583 | + to_type = self.low_precision_type |
| 584 | + elif len(tracker.high_precision_nodes) > 0: |
| 585 | + logger.debug( |
| 586 | + f"Convert initializer {init_name} to " |
| 587 | + f"{self.high_precision_type.str_short}, " |
| 588 | + "only used by high precision nodes" |
| 589 | + ) |
| 590 | + from_type = self.low_precision_type |
| 591 | + to_type = self.high_precision_type |
| 592 | + else: |
| 593 | + raise ValueError( |
| 594 | + f"Unexpected: initializer {init_name} is not used by any nodes" |
| 596 | + ) |
| 597 | + |
| 598 | + new_init = self._cast_initializer( |
| 599 | + init=init, |
| 600 | + from_type=from_type, |
| 601 | + to_type=to_type, |
| 602 | + low_precision_nodes=tracker.low_precision_nodes, |
| 603 | + high_precision_nodes=tracker.high_precision_nodes, |
| 604 | + ) |
| 605 | + if new_init is not None: |
| 606 | + self.model.graph.initializer.remove(init) |
| 607 | + self.model.graph.initializer.extend([new_init]) |
| 608 | + continue |
487 | 609 |
|
488 | | - # If initializer is too large, skip conversion, perform cast instead |
489 | | - if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes: |
| 610 | + # This initializer is used by both high precision and low precision nodes, so we need |
| 611 | + # to duplicate it for low precision nodes. |
| 612 | + assert len(tracker.low_precision_nodes) > 0 and len(tracker.high_precision_nodes) > 0 |
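| | + # Keep the existing copy for the consumers that already match its precision; the other |
| | + # group gets a renamed duplicate (created below) and has its inputs re-pointed to it. |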
| 613 | + if init.data_type == self.low_precision_type.onnx_type: |
490 | 614 | logger.debug( |
491 | | - f"Initializer {init.name} is too large, skipping initializer conversion, cast in " |
492 | | - "runtime instead" |
| 615 | + f"Convert initializer {init_name} to " |
| 616 | + f"{self.high_precision_type.str_short}, " |
| 617 | + "used by both high precision and low precision nodes" |
493 | 618 | ) |
494 | | - exclude_consumers = ( |
495 | | - low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes |
| 619 | + from_type = self.low_precision_type |
| 620 | + to_type = self.high_precision_type |
| 621 | + nodes_to_update = tracker.high_precision_nodes |
| 622 | + elif init.data_type == self.high_precision_type.onnx_type: |
| 623 | + logger.debug( |
| 624 | + f"Convert initializer {init_name} to " |
| 625 | + f"{self.low_precision_type.str_short}, " |
| 626 | + "used by both high precision and low precision nodes" |
496 | 627 | ) |
497 | | - self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers) |
498 | | - return True |
499 | | - try: |
500 | | - np_array = numpy_helper.to_array(init) |
501 | | - assert from_type.str_short in PRECISION_MAP |
502 | | - assert to_type.str_short in PRECISION_MAP |
503 | | - assert from_type.str_short != to_type.str_short |
504 | | - |
505 | | - if np_array.dtype == from_type.numpy_type: |
506 | | - consumers = [n.name for n in utils.get_consumer_nodes(self.model, init.name)] |
507 | | - should_duplicate = len(consumers) > 1 and set(consumers) & set( |
508 | | - high_precision_nodes |
509 | | - ) |
| 628 | + from_type = self.high_precision_type |
| 629 | + to_type = self.low_precision_type |
| 630 | + nodes_to_update = tracker.low_precision_nodes |
| 631 | + else: |
| 632 | + raise ValueError(f"Unexpected: initializer {init_name} is neither low nor high precision") |
| 633 | + |
| 634 | + new_init = self._cast_initializer( |
| 635 | + init=init, |
| 636 | + from_type=from_type, |
| 637 | + to_type=to_type, |
| 638 | + low_precision_nodes=tracker.low_precision_nodes, |
| 639 | + high_precision_nodes=tracker.high_precision_nodes, |
| 640 | + ) |
| 641 | + if new_init is not None: |
| 642 | + new_init_name = f"{init_name}_{to_type.str_short}" |
| 643 | + new_init.name = new_init_name |
| 644 | + for consumer in nodes_to_update: |
| 645 | + consumer.node.input[consumer.node_index] = new_init_name |
| 646 | + self.model.graph.initializer.extend([new_init]) |
| 647 | + |
| 648 | + def _cast_initializer( |
| 649 | + self, |
| 650 | + init: onnx.TensorProto, |
| 651 | + from_type: PrecisionTypes, |
| 652 | + to_type: PrecisionTypes, |
| 653 | + low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto], |
| 654 | + high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto], |
| 655 | + ) -> onnx.TensorProto | None: |
| 656 | + """Cast an initializer to a new precision based on its consumer nodes. |
510 | 657 |
|
511 | | - if should_duplicate: |
512 | | - # Create a new low precision copy with a different name |
513 | | - new_name = f"{init.name}_{to_type.str_short}" |
514 | | - logger.debug( |
515 | | - f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due " |
516 | | - f"to node {node.name}" |
517 | | - ) |
| 658 | + This method converts an initializer to a new precision while handling special cases like bfloat16 conversion |
| 659 | + and providing warnings when values are clamped or replaced due to precision limits. |
518 | 660 |
|
519 | | - # Update the node to use the new initializer |
520 | | - for i, input_name in enumerate(node.input): |
521 | | - if input_name == init.name: |
522 | | - node.input[i] = new_name |
523 | | - break |
| 661 | + Args: |
| 662 | + init: The initializer to cast. |
| 663 | + from_type: The original precision of the initializer. |
| 664 | + to_type: The new precision to cast the initializer to. |
| | + low_precision_nodes: Low precision consumers of the initializer, used to decide which |
| | + consumers to exclude when a runtime cast is inserted for an oversized initializer. |
| | + high_precision_nodes: High precision consumers of the initializer, used the same way. |
524 | 665 |
|
525 | | - if init.name in initializer_converted_dup: |
526 | | - return False |
527 | | - initializer_converted_dup.append(init.name) |
528 | | - else: |
529 | | - if init.name in initializer_converted: |
530 | | - return False |
531 | | - new_name = init.name |
532 | | - logger.debug( |
533 | | - f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}" |
534 | | - ) |
535 | | - initializer_converted.append(init.name) |
536 | | - self.model.graph.initializer.remove(init) |
537 | | - |
538 | | - # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead |
539 | | - if self._is_bf16(to_type) and self._is_fp32(from_type): |
540 | | - new_init = onnx.TensorProto() |
541 | | - new_init.dims.extend(np_array.shape) |
542 | | - new_init.name = new_name |
543 | | - new_init.data_type = onnx.TensorProto.BFLOAT16 |
544 | | - bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) |
545 | | - new_init.raw_data = bf16_bytes.tobytes() |
546 | | - else: |
547 | | - assert to_type.numpy_type is not None |
548 | | - data_max, data_lowest = ( |
549 | | - np.finfo(to_type.numpy_type).max, |
550 | | - np.finfo(to_type.numpy_type).smallest_subnormal, |
551 | | - ) |
552 | | - if np.any(np.abs(np_array) > data_max): |
553 | | - logger.warning( |
554 | | - f"Initializer {init.name} used by node {node.name} contains values larger than " |
555 | | - f"largest {to_type.str_short} value, values will be clamped to {data_max}." |
556 | | - ) |
557 | | - np_array = np.clip(np_array, -1 * data_max, data_max) |
558 | | - if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): |
559 | | - logger.warning( |
560 | | - f"Initializer {init.name} used by node {node.name} contains values smaller than " |
561 | | - f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}." |
562 | | - ) |
563 | | - np_array = np.where( |
564 | | - (np_array != 0.0) & (np.abs(np_array) < data_lowest), |
565 | | - data_lowest, |
566 | | - np_array, |
567 | | - ) |
568 | | - new_array = np_array.astype(to_type.numpy_type) |
569 | | - new_init = numpy_helper.from_array(new_array, new_name) |
570 | | - self.model.graph.initializer.extend([new_init]) |
571 | | - return True |
572 | | - return False |
573 | | - except Exception as e: |
574 | | - logger.error(f"Error converting initializer {init.name}: {e}") |
575 | | - return False |
| 666 | + Returns: |
| 667 | + onnx.TensorProto | None: The converted initializer, or None if the initializer was too |
| | + large to convert and a runtime cast was inserted instead. |
| 668 | + """ |
576 | 669 |
|
577 | | - initializer_converted = [] |
578 | | - initializer_converted_dup = [] |
579 | | - modified = False |
580 | | - for node in self.model.graph.node: |
581 | | - if node.name in low_precision_nodes: |
582 | | - for init in self.node_to_init_map[node.name]: |
583 | | - if self._should_skip_low_precision_input_conversion(node, init.name): |
584 | | - continue |
585 | | - modified |= convert_initializer( |
586 | | - init, |
587 | | - node, |
588 | | - from_type=self.high_precision_type, |
589 | | - to_type=self.low_precision_type, |
590 | | - ) |
591 | | - if modified: |
592 | | - _, _, self.node_to_init_map = utils.setup_mappings(self.model) |
593 | | - |
594 | | - if node.name in high_precision_nodes: |
595 | | - for init in self.node_to_init_map[node.name]: |
596 | | - convert_initializer( |
597 | | - init, |
598 | | - node, |
599 | | - from_type=self.low_precision_type, |
600 | | - to_type=self.high_precision_type, |
601 | | - ) |
| 670 | + def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: |
| 671 | + """Get the name of a node or input index tracker.""" |
| 672 | + if isinstance(node, onnx.NodeProto): |
| 673 | + return node.name |
| 674 | + elif isinstance(node, InputIndexTracker): |
| 675 | + return node.node.name |
| 676 | + else: |
| 677 | + raise ValueError(f"Unexpected: {type(node)}") |
| 678 | + |
| 679 | + # Ensure the initializer is of the expected type |
| 680 | + assert init.data_type == from_type.onnx_type, ( |
| 681 | + f"Initializer {init.name} is not of type {from_type.str_short}" |
| 682 | + ) |
| 683 | + |
| 684 | + if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes: |
| 685 | + # The initializer is too large, so we need to convert it at runtime. |
| 686 | + logger.debug( |
| 687 | + f"Initializer {init.name} is too large, skipping initializer conversion, cast in " |
| 688 | + "runtime instead" |
| 689 | + ) |
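| | + # Consumers that should keep the initializer at its current precision are excluded from |
| | + # the runtime cast; only the other group reads the value through the inserted Cast node. |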
| 690 | + exclude_consumers = ( |
| 691 | + low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes |
| 692 | + ) |
| 693 | + exclude_consumers_names = [_get_name(node) for node in exclude_consumers] |
| 696 | + self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names) |
| 697 | + return None |
| 698 | + |
| 699 | + np_array = numpy_helper.to_array(init) |
| 700 | + # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead |
| 701 | + if self._is_bf16(to_type) and self._is_fp32(from_type): |
| 702 | + new_init = onnx.TensorProto() |
| 703 | + new_init.dims.extend(np_array.shape) |
| 704 | + new_init.name = init.name |
| 705 | + new_init.data_type = onnx.TensorProto.BFLOAT16 |
| 706 | + bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) |
| 707 | + new_init.raw_data = bf16_bytes.tobytes() |
| 708 | + else: |
| 709 | + assert to_type.numpy_type is not None |
| 710 | + data_max, data_lowest = ( |
| 711 | + np.finfo(to_type.numpy_type).max, |
| 712 | + np.finfo(to_type.numpy_type).smallest_subnormal, |
| 713 | + ) |
| 714 | + if np.any(np.abs(np_array) > data_max): |
| 715 | + logger.warning( |
| 716 | + f"Initializer {init.name} contains values larger than largest " |
| 717 | + f"{to_type.str_short} value, values will be clamped to {data_max}." |
| 718 | + ) |
| 719 | + np_array = np.clip(np_array, -1 * data_max, data_max) |
| 720 | + if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): |
| 721 | + logger.warning( |
| 722 | + f"Initializer {init.name} contains values smaller than smallest " |
| 723 | + f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}." |
| 724 | + ) |
| 725 | + np_array = np.where( |
| 726 | + (np_array != 0.0) & (np.abs(np_array) < data_lowest), |
| 727 | + data_lowest, |
| 728 | + np_array, |
| 729 | + ) |
| 730 | + new_array = np_array.astype(to_type.numpy_type) |
| 731 | + new_init = numpy_helper.from_array(new_array, init.name) |
| 732 | + |
| 733 | + return new_init |
602 | 734 |
|
603 | 735 | def _replace_tensor_name( |
604 | 736 | self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str |
|