@@ -8,6 +8,7 @@
 
 import collections
 import functools
+import heapq
 import itertools
 import logging
 import operator
@@ -707,6 +708,137 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
                 return True
     return False
 
+def heap_optimized_greedy(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+) -> MemoryAlgoResult:
+    """
+    Greedy memory allocation that manages memory blocks with a priority queue (heap).
+
+    Buffer specs are sorted by size (largest first) and, for equal sizes, by the
+    start of their lifetime. Each spec is then placed into an existing memory block
+    whose resident intervals do not overlap the spec's lifetime in time; if no such
+    block exists, a new block is created. Heap entries have the form
+    (end_time, max_size, block_intervals, block_id), so popping the heap always
+    yields the block with the earliest end time.
+
+    Note: allow_overlapping_allocations is currently unused by this algorithm.
+    """
+
+    greedy_result = MemoryAlgoResult({}, [])
+
+    extra_padded_bytes = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padded_bytes = 64
+
+    # Don't assert in collect_specs_from_nodes if we have already encountered
+    # and ignored some to_out_variant errors.
+    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
+
+    specs_list = collect_specs_from_nodes(
+        graph_module.graph.nodes,
+        graph_signature,
+        do_assertion=do_assertion,
+        ignore_graph_input=not alloc_graph_input,
+        ignore_graph_output=not alloc_graph_output,
+    )
+    # Sort by spec size (descending), breaking ties by the start of the spec's
+    # lifetime.
+    specs_list = sorted(specs_list, key=lambda x: (-x.allocated_memory, x.lifetime[0]))
+
+    # Maps each memory id to the heap of blocks for that memory id.
+    # Heap entry format: (end_time, max_size, block_intervals, block_id).
+    heap_dict = collections.defaultdict(list)
+    # Maps each memory id to the next unused block id for that memory id, so
+    # that every block within a memory id gets a unique id.
+    block_ids = collections.defaultdict(int)
+    # Maps (mem_id, block_id) to the list of specs assigned to that block.
+    spec_to_block_id = collections.defaultdict(list)
+
+    for spec in specs_list:
+        spec.realign(alignment)
+        size, start, end = spec.allocated_memory, spec.lifetime[0], spec.lifetime[1]
+        assigned = False
+
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
+        if spec.mem_id is None:
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
+
+        # Get the heap for the memory id of the spec.
+        heap = heap_dict[spec_alloc_result.mem_id]
+
+        # Check the heap for compatible blocks.
+        temp = []
+        while heap:
+            block_end, block_size, block_intervals, block_id = heapq.heappop(heap)
+            # The block can fit the buffer if:
+            # 1. its max_size >= the buffer size, and
+            # 2. the buffer's lifetime does not overlap any interval already in
+            #    the block.
+            if block_size >= size and not any(
+                s < end and start < e for (s, e) in block_intervals
+            ):
+                # Add the buffer to the block.
+                block_intervals.append((start, end))
+                new_block_end = max(block_end, end)
+                heapq.heappush(
+                    heap, (new_block_end, block_size, block_intervals, block_id)
+                )
+                # Keep track of all the specs assigned to this block. Key by
+                # (mem_id, block_id) since block ids are only unique per memory id.
+                spec_to_block_id[(spec_alloc_result.mem_id, block_id)].append(spec)
+                assigned = True
+                break
+            else:
+                # The block is not compatible; stash it in a temporary list so
+                # we can restore it to the heap later.
+                temp.append((block_end, block_size, block_intervals, block_id))
+
+        # Restore the popped blocks to the heap.
+        for item in temp:
+            heapq.heappush(heap, item)
+
+        # Create a new block if no existing block fits.
+        if not assigned:
+            # Get the next unused block id for this memory id.
+            block_id = block_ids[spec_alloc_result.mem_id]
+            new_block = (end, size, [(start, end)], block_id)
+            # Add this spec to the list of specs assigned to this block.
+            spec_to_block_id[(spec_alloc_result.mem_id, block_id)].append(spec)
+            # Increment the next block id for this memory id.
+            block_ids[spec_alloc_result.mem_id] += 1
+            heapq.heappush(heap, new_block)
+
+    # Now that we have the heap for each memory id, assign an offset to each
+    # spec based on the blocks in that heap.
+    # Heap entry format: (end_time, max_size, block_intervals, block_id).
+    if len(heap_dict) == 0:
+        # No tensor in the graph needs to be allocated.
+        # Return [0, 0] to be consistent with the default behavior of naive.
+        bufsize = [0, 0]
+    else:
+        bufsize = [0] * (max(heap_dict.keys()) + 1)
+    for mem_id, heap in heap_dict.items():
+        input_total_size = 0
+        total_size = 0
+        if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
+            # pyre-fixme[6]: For 1st argument expected
+            #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+            if len(bufsizes) > mem_id:
+                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
+                input_total_size = bufsizes[mem_id]
+        while heap:
+            block_end, block_size, block_intervals, block_id = heapq.heappop(heap)
+            # Every spec in a block gets the block's base offset; specs sharing a
+            # block never overlap in time, so they can safely share the space.
+            for spec in spec_to_block_id[(mem_id, block_id)]:
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_offset = total_size
+            total_size += block_size
+        bufsize[mem_id] = input_total_size + total_size + extra_padded_bytes
+
+    greedy_result.bufsizes = bufsize
+    return greedy_result
 
 def greedy(
     graph_module: torch.fx.GraphModule,
@@ -818,7 +950,7 @@ def memory_planning_algorithm_suite(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy, heap_optimized_greedy],
 ) -> List[int]:
     r"""
     Memory planning algorithm suite that runs a list of memory planning algorithms
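
For intuition, here is a minimal, self-contained sketch of the packing scheme `heap_optimized_greedy` uses, independent of the ExecuTorch spec machinery. The buffer names, sizes, and lifetimes below are made up for illustration: buffers whose lifetimes never overlap share a block, and each block is as large as the first (largest) buffer placed in it.

```python
import heapq

# Hypothetical buffers: (name, size_in_bytes, (start, end)), with lifetimes
# given as node indices. All values are made up for illustration.
buffers = [
    ("a", 1024, (0, 4)),
    ("b", 1024, (5, 9)),  # disjoint from "a" in time, so it can reuse a's block
    ("c", 512, (2, 7)),   # overlaps both "a" and "b", so it needs its own block
]
# Sort by size (descending), then by lifetime start, as the planner does.
buffers.sort(key=lambda x: (-x[1], x[2][0]))

heap = []  # entries: (end_time, block_size, intervals, block_id)
next_block_id = 0
placement = {}  # buffer name -> block_id

for name, size, (start, end) in buffers:
    assigned = False
    popped = []
    while heap:
        block_end, block_size, intervals, block_id = heapq.heappop(heap)
        # A block fits if it is big enough and no resident interval overlaps
        # [start, end] in time (the same predicate the planner uses).
        if block_size >= size and not any(s < end and start < e for s, e in intervals):
            intervals.append((start, end))
            heapq.heappush(heap, (max(block_end, end), block_size, intervals, block_id))
            placement[name] = block_id
            assigned = True
            break
        popped.append((block_end, block_size, intervals, block_id))
    for item in popped:  # restore blocks that did not fit
        heapq.heappush(heap, item)
    if not assigned:
        heapq.heappush(heap, (end, size, [(start, end)], next_block_id))
        placement[name] = next_block_id
        next_block_id += 1

print(placement)  # {'a': 0, 'b': 0, 'c': 1}: "a" and "b" share block 0
# Blocks are laid out back to back, so each buffer's offset is the sum of the
# sizes of the blocks before its own.
offset, offsets = 0, {}
for _, block_size, _, block_id in sorted(heap, key=lambda blk: blk[3]):
    offsets[block_id] = offset
    offset += block_size
print({name: offsets[bid] for name, bid in placement.items()})  # a: 0, b: 0, c: 1024
print(offset)  # total footprint: 1536 bytes instead of 2560 with no sharing
```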
|
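The second hunk adds the new planner to the suite's default `algo_list`, so both planners run by default. A rough, self-contained sketch of the selection step follows, under the assumption that the suite keeps the plan with the smallest total footprint; the planner stubs and their returned sizes are placeholders, not the real API.

```python
from typing import Callable, List, Tuple


def pick_best_plan(
    planners: List[Callable[[], List[int]]]
) -> Tuple[str, List[int]]:
    """Run every planner and keep the plan with the smallest total footprint."""
    best = None
    for planner in planners:
        bufsizes = planner()  # per-memory-id buffer sizes, like the real algorithms return
        candidate = (sum(bufsizes), planner.__name__, bufsizes)
        if best is None or candidate[0] < best[0]:
            best = candidate
    assert best is not None, "algo_list must not be empty"
    return best[1], best[2]


# Placeholder planners standing in for greedy and heap_optimized_greedy.
def greedy() -> List[int]:
    return [0, 2048, 1024]


def heap_optimized_greedy() -> List[int]:
    return [0, 1536, 1024]


name, bufsizes = pick_best_plan([greedy, heap_optimized_greedy])
print(name, bufsizes)  # -> heap_optimized_greedy [0, 1536, 1024]
```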