@@ -15,7 +15,8 @@
 import torch
 from executorch.exir import ExecutorchProgramManager
 from executorch.exir.memory_planning import get_node_tensor_specs
-from executorch.exir.tensor import num_bytes_from_shape_and_dtype
+
+from executorch.exir.tensor import num_bytes_from_shape_and_dtype, TensorSpec
 from torch.export import ExportedProgram
 
 
@@ -53,10 +54,11 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
5354 """
5455 nodes = graph .nodes
5556 memory_timeline : List [Optional [MemoryTimeline ]] = [None for _ in range (len (nodes ))]
57+ unique_specs : set [TensorSpec ] = set ()
5658 for _ , node in enumerate (nodes ):
5759 if node .op == "output" :
5860 continue
59- if node .target == memory .alloc :
61+ if node .target == memory .alloc or node . target == memory . view :
6062 continue
6163 tensor_specs = get_node_tensor_specs (node )
6264 if tensor_specs is None :
@@ -65,6 +67,9 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
             # TODO: Make use of mem_id in the allocation info
             if tensor_spec is None or tensor_spec.mem_id is None or tensor_spec.const:
                 continue
+            if tensor_spec in unique_specs:
+                continue
+            unique_specs.add(tensor_spec)
             start, end = tensor_spec.lifetime
             size = num_bytes_from_shape_and_dtype(
                 typing.cast(torch.Size, tensor_spec.shape), tensor_spec.dtype
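
The `unique_specs` guard added above prevents double-counting: `get_node_tensor_specs` can return the same `TensorSpec` from more than one node (for example when several nodes refer to the same planned buffer), so each spec's bytes should enter the timeline only once. A minimal standalone sketch of the pattern, using a hypothetical `FakeSpec` stand-in rather than a real `TensorSpec`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen makes instances hashable, so they can live in a set
class FakeSpec:
    name: str
    nbytes: int


spec_a = FakeSpec("activation_0", 1024)
spec_b = FakeSpec("activation_1", 2048)

seen: set = set()
total_bytes = 0
# spec_a appears twice, as if two graph nodes returned the same spec
for spec in (spec_a, spec_a, spec_b):
    if spec in seen:
        continue  # same role as the `tensor_spec in unique_specs` check in the diff
    seen.add(spec)
    total_bytes += spec.nbytes

assert total_bytes == 3072  # counted once per spec, not once per node
```
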
@@ -75,6 +80,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
                 memory_timeline_j = memory_timeline[j]
                 if memory_timeline_j is None:
                     memory_timeline_j = MemoryTimeline()
+                    memory_timeline[j] = memory_timeline_j
                 assert memory_timeline_j
                 memory_timeline_j.allocations.append(
                     Allocation(
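
The one-line addition in this hunk fixes a write-back bug: the freshly created `MemoryTimeline` was previously only bound to the local `memory_timeline_j`, so `memory_timeline[j]` stayed `None`. A tiny illustration of the failure mode (generic Python, not the profiler's types):

```python
timeline = [None, None]

slot = timeline[0]
if slot is None:
    slot = object()     # without the write-back, timeline[0] would stay None
    timeline[0] = slot  # the added line persists the new entry in the list

assert timeline[0] is slot
```
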
@@ -106,6 +112,7 @@ def generate_memory_trace(
     chrome_trace_filename: str,
     enable_memory_offsets: bool = False,
     method_name: str = "forward",
+    ommit_metadata: bool = False,
 ):
     """
     Generate the memory timeline from the given ExecuTorch program.
@@ -151,13 +158,14 @@ def generate_memory_trace(
             e["pid"] = int(allocation.memory_id)
             e["tid"] = tid
             e["args"] = {}
-            e["args"]["op_name"] = f"{allocation.op_name}"
-            # ID refers to memory space, typically from 1 to N.
-            # For CPU, everything is allocated on one "space", other backends may have multiple.
-            e["args"]["Memory ID"] = allocation.memory_id
-            e["args"]["fqn"] = f"{allocation.fqn}"
-            e["args"]["source"] = f"{allocation.file_and_line_num}"
-            e["args"]["bytes"] = allocation.size_bytes
+            if not ommit_metadata:
+                e["args"]["op_name"] = f"{allocation.op_name}"
+                # ID refers to memory space, typically from 1 to N.
+                # For CPU, everything is allocated on one "space", other backends may have multiple.
+                e["args"]["Memory ID"] = allocation.memory_id
+                e["args"]["fqn"] = f"{allocation.fqn}"
+                e["args"]["source"] = f"{allocation.file_and_line_num}"
+                e["args"]["bytes"] = allocation.size_bytes
             start_time += allocation_size_kb
             trace_events.append(e)
         tid += 1
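
Taken together, the last two hunks let callers strip per-allocation metadata from the generated Chrome trace. A minimal usage sketch, assuming the profiler lives at `executorch.util.activation_memory_profiler` and that `et_program` is an `ExecutorchProgramManager` produced by `to_executorch()` (both are assumptions; only the keyword arguments shown in the diff are confirmed):

```python
# Hypothetical usage; the import path and `et_program` are assumptions.
from executorch.util.activation_memory_profiler import generate_memory_trace

generate_memory_trace(
    et_program,  # ExecutorchProgramManager; passed positionally, its name is not shown in the diff
    chrome_trace_filename="memory_trace.json",
    enable_memory_offsets=False,
    method_name="forward",
    ommit_metadata=True,  # new flag: skip the op_name/fqn/source/bytes args to shrink the JSON
)
# Load memory_trace.json in chrome://tracing or Perfetto to view the timeline.
```
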