8 | 8 | import copy  | 
9 | 9 | import logging  | 
10 | 10 | from contextlib import contextmanager, nullcontext  | 
 | 11 | +from dataclasses import dataclass  | 
11 | 12 | from functools import singledispatch  | 
12 |  | -from typing import Generator, List  | 
 | 13 | +from typing import Dict, Generator, List  | 
13 | 14 |   | 
14 | 15 | import torch  | 
15 | 16 |   | 
@@ -417,3 +418,373 @@ def to_backend(  | 
417 | 418 |         constants=tagged_exported_program.constants,  | 
418 | 419 |         verifiers=[tagged_exported_program.verifier],  | 
419 | 420 |     )  | 
 | 421 | + | 
 | 422 | + | 
 | 423 | +def _create_partitions_in_graph_module(  | 
 | 424 | +    tagged_graph_module: torch.fx.GraphModule,  | 
 | 425 | +    partition_result: PartitionResult,  | 
 | 426 | +    owning_program: ExportedProgram,  | 
 | 427 | +    is_submodule: bool,  | 
 | 428 | +) -> Dict[str, List[torch.fx.Node]]:  | 
 | 429 | +    backend_id_to_submodule_name = {}  | 
 | 430 | +    for tag, delegation_spec in partition_result.partition_tags.items():  | 
 | 431 | +        # Create partition with nodes containing this tag. There should only be  | 
 | 432 | +        # one contained submodule per tag  | 
 | 433 | +        node_list = _get_node_list_with_same_tag(  | 
 | 434 | +            tagged_graph_module, tag, owning_program  | 
 | 435 | +        )  | 
 | 436 | + | 
 | 437 | +        if len(node_list) == 0:  | 
 | 438 | +            logging.debug(f"Did not find any nodes for tag {tag}")  | 
 | 439 | +            continue  | 
 | 440 | + | 
 | 441 | +        logging.debug(f"For tag {tag}, found nodes {node_list}")  | 
 | 442 | +        # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)  | 
 | 443 | + | 
 | 444 | +        replace_ctx = (  | 
 | 445 | +            tagged_graph_module._set_replace_hook(  | 
 | 446 | +                owning_program.graph_signature.get_replace_hook()  | 
 | 447 | +            )  | 
 | 448 | +            if not is_submodule  | 
 | 449 | +            else nullcontext()  | 
 | 450 | +        )  | 
 | 451 | +        with replace_ctx:  | 
 | 452 | +            submodule, call_module_node = create_submodule_from_nodes(  | 
 | 453 | +                tagged_graph_module, node_list, tag  | 
 | 454 | +            )  | 
 | 455 | + | 
 | 456 | +        tagged_graph_module_output_node = [  | 
 | 457 | +            node for node in tagged_graph_module.graph.nodes if node.op == "output"  | 
 | 458 | +        ][0]  | 
 | 459 | +        submodule_output_node = [  | 
 | 460 | +            node for node in submodule.graph.nodes if node.op == "output"  | 
 | 461 | +        ][0]  | 
 | 462 | +        # Copy the output node meta from the original output node, because  | 
 | 463 | +        # create_submodule_from_nodes doesn't cover the meta field  | 
 | 464 | +        submodule_output_node.meta = tagged_graph_module_output_node.meta  | 
 | 465 | +        logging.debug(f"Partitioned graph module: {tagged_graph_module}")  | 
 | 466 | +        (  | 
 | 467 | +            submodule_program,  | 
 | 468 | +            toplevel_input_specs_to_delete,  | 
 | 469 | +            toplevel_output_specs_to_delete,  | 
 | 470 | +        ) = create_exported_program_from_submodule(  | 
 | 471 | +            submodule,  | 
 | 472 | +            owning_program,  | 
 | 473 | +            tag,  | 
 | 474 | +            call_module_node,  | 
 | 475 | +            is_submodule,  | 
 | 476 | +        )  | 
 | 477 | +        call_module_node.meta["backend_id"] = delegation_spec.backend_id  | 
 | 478 | +        call_module_node.meta["compile_spec"] = delegation_spec.compile_specs  | 
 | 479 | +        call_module_node.meta["submodule_program"] = submodule_program  | 
 | 480 | +        call_module_node.meta["toplevel_input_specs_to_delete"] = (  | 
 | 481 | +            toplevel_input_specs_to_delete  | 
 | 482 | +        )  | 
 | 483 | +        call_module_node.meta["toplevel_output_specs_to_delete"] = (  | 
 | 484 | +            toplevel_output_specs_to_delete  | 
 | 485 | +        )  | 
 | 486 | +        call_module_node.meta["is_submodule"] = is_submodule  | 
 | 487 | + | 
 | 488 | +        if delegation_spec.backend_id not in backend_id_to_submodule_name:  | 
 | 489 | +            backend_id_to_submodule_name[delegation_spec.backend_id] = []  | 
 | 490 | + | 
 | 491 | +        # The call_module_node created here might not be the same node instance as  | 
 | 492 | +        # the one in the final graph module. This is because this node might be replaced  | 
 | 493 | +        # in future edits to the graph. As a result, we just keep track of the node's name  | 
 | 494 | +        # and at the end we search for this node in our final graph module  | 
 | 495 | +        backend_id_to_submodule_name[delegation_spec.backend_id].append(  | 
 | 496 | +            call_module_node.target  | 
 | 497 | +        )  | 
 | 498 | + | 
 | 499 | +    created_submodule_nodes = dict(  | 
 | 500 | +        (key, []) for key in backend_id_to_submodule_name.keys()  | 
 | 501 | +    )  | 
 | 502 | +    for backend_id, submodule_name in backend_id_to_submodule_name.items():  | 
 | 503 | +        for node in tagged_graph_module.graph.nodes:  | 
 | 504 | +            if node.op == "call_module" and node.target in submodule_name:  | 
 | 505 | +                created_submodule_nodes[backend_id].append(node)  | 
 | 506 | + | 
 | 507 | +    # Check that the numbers of submodule names and submodule nodes are equal  | 
 | 508 | +    for backend_id in created_submodule_nodes.keys():  | 
 | 509 | +        assert len(created_submodule_nodes[backend_id]) == len(  | 
 | 510 | +            backend_id_to_submodule_name[backend_id]  | 
 | 511 | +        )  | 
 | 512 | + | 
 | 513 | +    return created_submodule_nodes  | 
 | 514 | + | 
 | 515 | + | 
 | 516 | +def _create_partitions(  | 
 | 517 | +    tagged_graph_module: torch.fx.GraphModule,  | 
 | 518 | +    partition_result: PartitionResult,  | 
 | 519 | +    owning_program: ExportedProgram,  | 
 | 520 | +    is_submodule: bool = False,  | 
 | 521 | +) -> Dict[str, List[torch.fx.Node]]:  | 
 | 522 | +    backend_id_to_call_submodules = _create_partitions_in_graph_module(  | 
 | 523 | +        tagged_graph_module, partition_result, owning_program, is_submodule  | 
 | 524 | +    )  | 
 | 525 | + | 
 | 526 | +    # Recursively partition and lower for submodules  | 
 | 527 | +    for _, submod, _ in get_control_flow_submodules(tagged_graph_module):  | 
 | 528 | +        nested_backend_id_to_call_submodules = _create_partitions(  | 
 | 529 | +            submod, partition_result, owning_program, is_submodule=True  | 
 | 530 | +        )  | 
 | 531 | +        for (  | 
 | 532 | +            backend_id,  | 
 | 533 | +            nested_submodules,  | 
 | 534 | +        ) in nested_backend_id_to_call_submodules.items():  | 
 | 535 | +            if backend_id not in backend_id_to_call_submodules:  | 
 | 536 | +                backend_id_to_call_submodules[backend_id] = nested_submodules  | 
 | 537 | +            else:  | 
 | 538 | +                backend_id_to_call_submodules[backend_id].extend(nested_submodules)  | 
 | 539 | + | 
 | 540 | +    return backend_id_to_call_submodules  | 
 | 541 | + | 
 | 542 | + | 
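
Aside (not part of this diff): a minimal sketch of the grouping step that `_create_partitions_in_graph_module` performs per tag, using plain torch.fx. The toy module and the `delegation_tag` meta key below are illustrative assumptions; the real flow then outlines each group via `create_submodule_from_nodes` and records the resulting call_module targets per backend_id.

from collections import defaultdict

import torch


class _ToyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1) * 2


gm = torch.fx.symbolic_trace(_ToyModule())

# Pretend a partitioner tagged every call_function node for delegation;
# "delegation_tag" is an assumed meta key for this sketch.
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["delegation_tag"] = "tag0"

# Group nodes that share a tag, analogous to _get_node_list_with_same_tag;
# each group would then become one submodule (and one call_module node).
nodes_by_tag = defaultdict(list)
for node in gm.graph.nodes:
    tag = node.meta.get("delegation_tag")
    if tag is not None:
        nodes_by_tag[tag].append(node)

print({tag: [n.name for n in nodes] for tag, nodes in nodes_by_tag.items()})
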
 | 543 | +def lower_all_submodules_to_backend(  | 
 | 544 | +    backend_id: str,  | 
 | 545 | +    method_to_submodules_nodes: Dict[str, List[torch.fx.Node]],  | 
 | 546 | +    method_to_tagged_edge_program: Dict[str, ExportedProgram],  | 
 | 547 | +) -> None:  | 
 | 548 | +    """  | 
 | 549 | +    Lower all submodule nodes given in the method_to_submodules_nodes map to backend_id.  | 
 | 550 | +    """  | 
 | 551 | +    # The exported programs created for the submodules live in each call_module node's metadata,  | 
 | 552 | +    # so we can map method_to_submodules_nodes directly to per-method lists of partitioned exported programs  | 
 | 553 | +    method_to_partitioned_program = {  | 
 | 554 | +        method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]  | 
 | 555 | +        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()  | 
 | 556 | +    }  | 
 | 557 | +    method_to_compile_specs = {  | 
 | 558 | +        method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]  | 
 | 559 | +        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()  | 
 | 560 | +    }  | 
 | 561 | +    backend_found = False  | 
 | 562 | +    for cls in BackendDetails.__subclasses__():  | 
 | 563 | +        if backend_id == cls.__name__:  | 
 | 564 | +            method_to_preprocess_result: dict[str, List[PreprocessResult]] = (  | 
 | 565 | +                cls.preprocess_multimethod(  | 
 | 566 | +                    method_to_partitioned_program, method_to_compile_specs  | 
 | 567 | +                )  | 
 | 568 | +            )  | 
 | 569 | +            backend_found = True  | 
 | 570 | + | 
 | 571 | +    if not backend_found:  | 
 | 572 | +        raise NotImplementedError(f"Backend {backend_id} was not found.")  | 
 | 573 | + | 
 | 574 | +    for method_name in method_to_preprocess_result.keys():  | 
 | 575 | +        owning_program = method_to_tagged_edge_program[method_name]  | 
 | 576 | +        list_of_preprocess_results = method_to_preprocess_result[method_name]  | 
 | 577 | +        list_of_call_submodule_nodes = method_to_submodules_nodes[method_name]  | 
 | 578 | +        list_of_compile_specs = method_to_compile_specs[method_name]  | 
 | 579 | +        assert len(list_of_preprocess_results) == len(list_of_call_submodule_nodes), (  | 
 | 580 | +            f"Expected {len(list_of_call_submodule_nodes)} preprocessed results for "  | 
 | 581 | +            f"method {method_name} but got {len(list_of_preprocess_results)}"  | 
 | 582 | +        )  | 
 | 583 | +        for preprocess_result, call_submodule_node, compile_spec in zip(  | 
 | 584 | +            list_of_preprocess_results,  | 
 | 585 | +            list_of_call_submodule_nodes,  | 
 | 586 | +            list_of_compile_specs,  | 
 | 587 | +        ):  | 
 | 588 | +            submodule_program = call_submodule_node.meta["submodule_program"]  | 
 | 589 | +            lowered_module = LoweredBackendModule(  | 
 | 590 | +                edge_program=submodule_program,  | 
 | 591 | +                backend_id=backend_id,  | 
 | 592 | +                processed_bytes=preprocess_result.processed_bytes,  | 
 | 593 | +                compile_specs=compile_spec,  | 
 | 594 | +            )  | 
 | 595 | +            owning_graph_module = call_submodule_node.graph.owning_module  | 
 | 596 | +            is_submodule = call_submodule_node.meta["is_submodule"]  | 
 | 597 | +            toplevel_input_specs_to_delete = call_submodule_node.meta[  | 
 | 598 | +                "toplevel_input_specs_to_delete"  | 
 | 599 | +            ]  | 
 | 600 | +            toplevel_output_specs_to_delete = call_submodule_node.meta[  | 
 | 601 | +                "toplevel_output_specs_to_delete"  | 
 | 602 | +            ]  | 
 | 603 | +            # call delegate args should only use user_inputs  | 
 | 604 | +            call_delegate_args = []  | 
 | 605 | +            # Preserve input order as user_inputs  | 
 | 606 | +            for inp_name in submodule_program.graph_signature.user_inputs:  | 
 | 607 | +                for inp_node in call_submodule_node.all_input_nodes:  | 
 | 608 | +                    if inp_node.name == inp_name:  | 
 | 609 | +                        call_delegate_args.append(inp_node)  | 
 | 610 | +                        break  | 
 | 611 | + | 
 | 612 | +            def generate_debug_handle(ep: ExportedProgram) -> int:  | 
 | 613 | +                """  | 
 | 614 | +                Generate a debug handle for the given ExportedProgram.  | 
 | 615 | +                """  | 
 | 616 | +                debug_handle = 0  | 
 | 617 | +                for node in ep.graph_module.graph.nodes:  | 
 | 618 | +                    debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))  | 
 | 619 | +                return debug_handle + 1  | 
 | 620 | + | 
 | 621 | +            # Replace the partitioned submodule with a lowered submodule  | 
 | 622 | +            # Add call_method node with function "forward"  | 
 | 623 | +            with owning_graph_module.graph.inserting_before(call_submodule_node):  | 
 | 624 | +                lowered_name = get_lowered_module_name(  | 
 | 625 | +                    owning_graph_module, lowered_module  | 
 | 626 | +                )  | 
 | 627 | +                lowered_node = owning_graph_module.graph.get_attr(lowered_name)  | 
 | 628 | +                call_delegate_node = owning_graph_module.graph.call_function(  | 
 | 629 | +                    executorch_call_delegate,  | 
 | 630 | +                    (lowered_node,) + tuple(call_delegate_args),  | 
 | 631 | +                    call_submodule_node.kwargs,  | 
 | 632 | +                )  | 
 | 633 | +                call_delegate_node.meta["debug_handle"] = generate_debug_handle(  | 
 | 634 | +                    owning_program  | 
 | 635 | +                )  | 
 | 636 | +                call_delegate_node.meta["val"] = call_submodule_node.meta["val"]  | 
 | 637 | +                call_submodule_node.replace_all_uses_with(call_delegate_node)  | 
 | 638 | +                owning_graph_module.graph.erase_node(call_submodule_node)  | 
 | 639 | + | 
 | 640 | +            if is_submodule:  | 
 | 641 | +                assert len(toplevel_input_specs_to_delete) == 0  | 
 | 642 | +                assert len(toplevel_output_specs_to_delete) == 0  | 
 | 643 | +            elif (  | 
 | 644 | +                len(toplevel_input_specs_to_delete) > 0  | 
 | 645 | +                or len(toplevel_output_specs_to_delete) > 0  | 
 | 646 | +            ):  | 
 | 647 | +                _unsafe_adjust_original_program(  | 
 | 648 | +                    owning_program,  | 
 | 649 | +                    call_delegate_node,  | 
 | 650 | +                    toplevel_input_specs_to_delete,  | 
 | 651 | +                    toplevel_output_specs_to_delete,  | 
 | 652 | +                )  | 
 | 653 | + | 
 | 654 | + | 
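
Aside (not part of this diff): a hedged sketch of a `BackendDetails` subclass that `lower_all_submodules_to_backend` would discover by class name. The `preprocess_multimethod` signature is inferred from the call site above; the class and its serialization are placeholders, not a real backend.

from typing import Dict, List

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram


class ExampleBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram, compile_specs: List[CompileSpec]
    ) -> PreprocessResult:
        # Placeholder serialization; a real backend would compile the graph here.
        return PreprocessResult(processed_bytes=str(edge_program.graph).encode())

    @classmethod
    def preprocess_multimethod(
        cls,
        edge_programs: Dict[str, List[ExportedProgram]],
        compile_specs: Dict[str, List[List[CompileSpec]]],
    ) -> Dict[str, List[PreprocessResult]]:
        # Simplest strategy: preprocess each partitioned program independently,
        # preserving the per-method list order expected by the caller.
        return {
            method: [
                cls.preprocess(program, specs)
                for program, specs in zip(programs, compile_specs[method])
            ]
            for method, programs in edge_programs.items()
        }
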
 | 655 | +@dataclass  | 
 | 656 | +class MethodProgramsPartitionerSpec:  | 
 | 657 | +    """  | 
 | 658 | +    Since single dispatch for to_backend requires the first argument to be a  | 
 | 659 | +    valid class, we create the following dataclass spec to hold the dictionaries  | 
 | 660 | +    mapping each method name to its corresponding program and partitioner  | 
 | 661 | +    """  | 
 | 662 | + | 
 | 663 | +    method_to_edge_program: Dict[str, ExportedProgram]  | 
 | 664 | +    method_to_partitioner: Dict[str, Partitioner]  | 
 | 665 | + | 
 | 666 | + | 
 | 667 | +@to_backend.register  | 
 | 668 | +def _(  | 
 | 669 | +    method_edge_program_partitioners: MethodProgramsPartitionerSpec,  | 
 | 670 | +) -> Dict[str, ExportedProgram]:  | 
 | 671 | +    """  | 
 | 672 | +    Add overloaded implementations for to_backend:  | 
 | 673 | + | 
 | 674 | +    ::  | 
 | 675 | + | 
 | 676 | +     def to_backend(  | 
 | 677 | +        method_edge_program_partitioners: MethodProgramsPartitionerSpec  | 
 | 678 | +    ) -> Dict[str, ExportedProgram]:  | 
 | 679 | + | 
 | 680 | +    Returns a dictionary of programs semantically equivalent to the input programs  | 
 | 681 | +    (each represented as a graph module in Edge dialect), but with portions of each program targeted for  | 
 | 682 | +    delegation as determined by its partitioner.  | 
 | 683 | + | 
 | 684 | +    Args:  | 
 | 685 | +        method_edge_program_partitioners: contains two mappings,  | 
 | 686 | +        - method_to_edge_program: mapping of method names to their respective programs in Edge dialect.  | 
 | 687 | +        - method_to_partitioner: mapping of method names to a partitioner instance in charge of tagging  | 
 | 688 | +        portions of the corresponding program for delegation. A valid partitioner must return a PartitionResult  | 
 | 689 | +        including both the tagged exported program and partition_tags: Dict[str, DelegationSpec], where each key is a tag name and  | 
 | 690 | +        the nodes with the same tag will be fused into one subgraph and delegated to the backend specified in the delegation spec.  | 
 | 691 | + | 
 | 692 | + | 
 | 693 | +    Returns:  | 
 | 694 | +        Dict[str, ExportedProgram]: The input programs, with some portions of each targeted for delegation.  | 
 | 695 | +    """  | 
 | 696 | +    method_to_edge_program = method_edge_program_partitioners.method_to_edge_program  | 
 | 697 | +    method_to_partitioner = method_edge_program_partitioners.method_to_partitioner  | 
 | 698 | + | 
 | 699 | +    partitioned_and_lowered_exported_programs = {}  | 
 | 700 | +    backend_id_to_method_submodules_map = {}  | 
 | 701 | +    method_to_tagged_exported_program = {}  | 
 | 702 | + | 
 | 703 | +    for method_name, partitioner_instance in method_to_partitioner.items():  | 
 | 704 | +        assert (  | 
 | 705 | +            method_name in method_to_edge_program  | 
 | 706 | +        ), f"Edge program for method {method_name} is not provided"  | 
 | 707 | +        edge_program = method_to_edge_program[method_name]  | 
 | 708 | +        edge_program._validate()  | 
 | 709 | + | 
 | 710 | +        # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.  | 
 | 711 | +        # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.  | 
 | 712 | +        try:  | 
 | 713 | +            fake_edge_program = get_fake_program(edge_program)  | 
 | 714 | +        except Exception as e:  | 
 | 715 | +            logging.warning(  | 
 | 716 | +                f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"  | 
 | 717 | +            )  | 
 | 718 | +            fake_edge_program = copy.deepcopy(edge_program)  | 
 | 719 | +        partitioner_result = partitioner_instance(fake_edge_program)  | 
 | 720 | +        tagged_exported_program = partitioner_result.tagged_exported_program  | 
 | 721 | +        method_to_tagged_exported_program[method_name] = tagged_exported_program  | 
 | 722 | + | 
 | 723 | +        # Check that the partitioner did not modify the original graph  | 
 | 724 | +        if _ENABLE_VALIDATION:  | 
 | 725 | +            assert is_identical_graph(  | 
 | 726 | +                tagged_exported_program.graph_module,  | 
 | 727 | +                edge_program.graph_module,  | 
 | 728 | +            ), f"The partitioner {partitioner_instance} should not modify the graph module"  | 
 | 729 | +        else:  | 
 | 730 | +            logging.warning("Disabled validating the partitioner.")  | 
 | 731 | + | 
 | 732 | +        assert (  | 
 | 733 | +            partitioner_result.partition_tags is not None  | 
 | 734 | +        ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"  | 
 | 735 | + | 
 | 736 | +        update_to_real_program(tagged_exported_program, edge_program)  | 
 | 737 | + | 
 | 738 | +        for tag, _ in partitioner_result.partition_tags.items():  | 
 | 739 | +            _maybe_duplicate_constant_nodes(tagged_exported_program, tag)  | 
 | 740 | + | 
 | 741 | +        backend_id_to_call_submodule_nodes = _create_partitions(  | 
 | 742 | +            tagged_exported_program.graph_module,  | 
 | 743 | +            partitioner_result,  | 
 | 744 | +            tagged_exported_program,  | 
 | 745 | +        )  | 
 | 746 | +        for (  | 
 | 747 | +            backend_id,  | 
 | 748 | +            call_submodule_nodes,  | 
 | 749 | +        ) in backend_id_to_call_submodule_nodes.items():  | 
 | 750 | +            if backend_id not in backend_id_to_method_submodules_map:  | 
 | 751 | +                backend_id_to_method_submodules_map[backend_id] = {}  | 
 | 752 | +            backend_id_to_method_submodules_map[backend_id][  | 
 | 753 | +                method_name  | 
 | 754 | +            ] = call_submodule_nodes  | 
 | 755 | + | 
 | 756 | +    for (  | 
 | 757 | +        backend_id,  | 
 | 758 | +        method_to_submodule_nodes,  | 
 | 759 | +    ) in backend_id_to_method_submodules_map.items():  | 
 | 760 | +        lower_all_submodules_to_backend(  | 
 | 761 | +            backend_id,  | 
 | 762 | +            method_to_submodule_nodes,  | 
 | 763 | +            method_to_tagged_exported_program,  | 
 | 764 | +        )  | 
 | 765 | + | 
 | 766 | +    for method_name in method_to_edge_program.keys():  | 
 | 767 | +        if method_name in method_to_tagged_exported_program:  | 
 | 768 | +            tagged_exported_program = method_to_tagged_exported_program[method_name]  | 
 | 769 | +            partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(  | 
 | 770 | +                root=tagged_exported_program.graph_module,  | 
 | 771 | +                graph=tagged_exported_program.graph_module.graph,  | 
 | 772 | +                graph_signature=tagged_exported_program.graph_signature,  | 
 | 773 | +                state_dict=tagged_exported_program.state_dict,  | 
 | 774 | +                range_constraints=copy.deepcopy(  | 
 | 775 | +                    tagged_exported_program.range_constraints  | 
 | 776 | +                ),  | 
 | 777 | +                module_call_graph=copy.deepcopy(  | 
 | 778 | +                    tagged_exported_program.module_call_graph  | 
 | 779 | +                ),  | 
 | 780 | +                example_inputs=None,  | 
 | 781 | +                constants=tagged_exported_program.constants,  | 
 | 782 | +                verifiers=[tagged_exported_program.verifier],  | 
 | 783 | +            )  | 
 | 784 | +        else:  | 
 | 785 | +            # this edge program wasn't partitioned, so we can just return it as is  | 
 | 786 | +            partitioned_and_lowered_exported_programs[method_name] = (  | 
 | 787 | +                method_to_edge_program[method_name]  | 
 | 788 | +            )  | 
 | 789 | + | 
 | 790 | +    return partitioned_and_lowered_exported_programs  | 
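
Aside (not part of this diff): a hypothetical usage sketch of the new `to_backend` overload. `ExamplePartitioner` and the two edge programs are placeholders the caller would supply; only `to_backend` and `MethodProgramsPartitionerSpec` come from this module.

from executorch.exir.backend.backend_api import (
    MethodProgramsPartitionerSpec,
    to_backend,
)

# Placeholders: one Edge-dialect ExportedProgram per method and a partitioner
# that tags nodes for delegation to a chosen backend.
method_to_edge_program = {
    "forward": forward_edge_program,
    "encode": encode_edge_program,
}
method_to_partitioner = {
    "forward": ExamplePartitioner(),
    "encode": ExamplePartitioner(),
}

lowered = to_backend(
    MethodProgramsPartitionerSpec(method_to_edge_program, method_to_partitioner)
)
# `lowered` maps each method name to an ExportedProgram whose partitioned
# submodules have been replaced by executorch_call_delegate calls.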