99import json
1010import typing
1111from dataclasses import dataclass , field
12- from typing import List
12+ from typing import Any , Dict , List , Optional
1313
1414import executorch .exir .memory as memory
1515import torch
@@ -52,7 +52,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
5252 allocations at that timestep.
5353 """
5454 nodes = graph .nodes
55- memory_timeline = [None ] * len (nodes )
55+ memory_timeline : List [ Optional [ MemoryTimeline ]] = [None for _ in range ( len (nodes ))]
5656 for _ , node in enumerate (nodes ):
5757 if node .op == "output" :
5858 continue
@@ -72,11 +72,11 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
7272 stack_trace = node .meta .get ("stack_trace" )
7373 fqn = _get_module_hierarchy (node )
7474 for j in range (start , end + 1 ):
75- if memory_timeline [j ] is None :
76- # pyre-ignore
77- memory_timeline [ j ] = MemoryTimeline ()
78- # pyre-ignore
79- memory_timeline [ j ] .allocations .append (
75+ memory_timeline_j = memory_timeline [j ]
76+ if memory_timeline_j is None :
77+ memory_timeline_j = MemoryTimeline ()
78+ assert memory_timeline_j
79+ memory_timeline_j .allocations .append (
8080 Allocation (
8181 node .name ,
8282 node .target ,
@@ -87,8 +87,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
8787 stack_trace ,
8888 )
8989 )
90- # pyre-ignore
91- return memory_timeline
90+ return memory_timeline # type: ignore[return-value]
9291
9392
9493def _validate_memory_planning_is_done (exported_program : ExportedProgram ):
@@ -129,7 +128,7 @@ def generate_memory_trace(
129128
130129 memory_timeline = create_tensor_allocation_info (exported_program .graph )
131130 root = {}
132- trace_events = []
131+ trace_events : List [ Dict [ str , Any ]] = []
133132 root ["traceEvents" ] = trace_events
134133
135134 tid = 0
@@ -138,7 +137,7 @@ def generate_memory_trace(
138137 if memory_timeline_event is None :
139138 continue
140139 for allocation in memory_timeline_event .allocations :
141- e = {}
140+ e : Dict [ str , Any ] = {}
142141 e ["name" ] = allocation .name
143142 e ["cat" ] = "memory_allocation"
144143 e ["ph" ] = "X"
0 commit comments