 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
+# pyre-strict

 import logging
 import math
-import typing
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias


 @dataclass(frozen=True)
-class SourceInfo:
+class RelativePlacementConstraint:
3231 """Information of source node and offset used for views."""

     source: torch.fx.Node
     offset: int = 0


+@dataclass(frozen=True)
+class AbsolutePlacementConstraint:
39+ """Information on placement constraint memory id and offset."""
+
+    pinned_memory_id: int
+
+    # If offset is None, then the tensor can be placed anywhere in the memory id.
+    offset: Optional[int] = None
+
+
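For concreteness, here is a minimal sketch of how the two constraint kinds are constructed. The dataclass definitions are copied from the hunk above so the sketch is self-contained; the placeholder node, byte offset, and memory ids are invented for illustration:

```python
from dataclasses import dataclass
from typing import Optional

import torch


# Copied from the hunk above so this sketch runs on its own.
@dataclass(frozen=True)
class RelativePlacementConstraint:
    source: torch.fx.Node
    offset: int = 0


@dataclass(frozen=True)
class AbsolutePlacementConstraint:
    pinned_memory_id: int
    offset: Optional[int] = None


g = torch.fx.Graph()
src = g.placeholder("x")  # a bare torch.fx.Node standing in for a real graph node

rel = RelativePlacementConstraint(source=src, offset=16)  # 16 bytes past src's buffer
pin_anywhere = AbsolutePlacementConstraint(pinned_memory_id=2)  # any offset in memory 2
pin_exact = AbsolutePlacementConstraint(pinned_memory_id=2, offset=128)  # exact spot
```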
 class MemConstraints:
     """
     This class contains all the tensor placement constraints that we create
     during memory planning.
-    Any tensor whose placement is derived off another tensor via a constraint
-    is not included in memory planning, and is marked as skipped.
+
+    We have two types of placement constraints:
+    1. Relative placement constraints: These specify the placement of a
+       tensor relative to another tensor. For example, when the slice dim
+       is 0, the slice output can be placed relative to its input, and the
+       op can be replaced with a nop.
+    2. Absolute placement constraints: These specify the absolute placement
+       of a tensor: either a specific memory id, or a specific memory id
+       and offset. For example, for operators that require a specific
+       memory id + offset, we can use this constraint to specify the
+       location of inputs/outputs or even temporary buffers.
4462 """
4563
4664 def __init__ (
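As a worked example of why relative placement makes ops removable (the numbers are mine, not from the commit): a dim-0 slice of a contiguous tensor is just its input's buffer at a fixed byte offset, so once that offset is recorded as a constraint, the slice itself can become a nop:

```python
import math

# Contiguous float32 input of shape (4, 8), sliced as x[1:3] along dim 0.
shape = (4, 8)
elem_bytes = 4  # float32
start = 1

row_bytes = math.prod(shape[1:]) * elem_bytes  # bytes per dim-0 row
offset = start * row_bytes
assert offset == 32  # slice output = input buffer + 32 bytes
```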
@@ -62,29 +80,38 @@ def __init__(
         # A set of tensor spec ids that must be skipped during memory allocation.
         # The exact mem_id and offset of the skipped tensors will be computed from
         # the constraints.
-        self._source_node: dict[int, SourceInfo] = {}
+        self._relative_placement_constraint: dict[int, RelativePlacementConstraint] = {}

         # A map from `id(TensorSpec)` to a set of mem_ids that cannot be used for
         # allocating the tensor.
         self._mem_id_blocklist: dict[int, set[int]] = {}

-    def get_source_info(self, node: torch.fx.Node) -> Optional[SourceInfo]:
+        # A map from `id(TensorSpec)` to an AbsolutePlacementConstraint that specifies the mem_id and, optionally, an exact offset.
+        self._absolute_placement_constraints: dict[int, AbsolutePlacementConstraint] = (
+            {}
+        )
+
+    def get_relative_placement_source(
+        self, node: torch.fx.Node
+    ) -> Optional[RelativePlacementConstraint]:
         spec = node.meta.get("spec")
         spec_id = id(spec)
-        if spec_id not in self._source_node:
+        if spec_id not in self._relative_placement_constraint:
             return None
-        return self._source_node[spec_id]
+        return self._relative_placement_constraint[spec_id]

-    def set_source_info(
-        self, dependent: torch.fx.Node, source_info: SourceInfo
+    def set_relative_placement_constraint(
+        self,
+        dependent: torch.fx.Node,
+        placement_constraint: RelativePlacementConstraint,
     ) -> None:
         dependent_spec = dependent.meta.get("spec")
         spec_id = id(dependent_spec)
-        self._source_node[spec_id] = source_info
-        if self.is_memory_planned(source_info.source):
+        self._relative_placement_constraint[spec_id] = placement_constraint
+        if self.is_memory_planned(placement_constraint.source):
             # Only add dependent nodes if source node needs memory planning.
             self.unresolved_loc_constraints[
-                id(source_info.source.meta.get("spec"))
+                id(placement_constraint.source.meta.get("spec"))
             ].add(dependent)

     def add_mem_id_to_blocklist(self, spec: TensorSpec, mem_id: int) -> None:
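Note that all of the bookkeeping above is keyed by `id(TensorSpec)` rather than by the spec itself; a toy sketch of that identity-keyed pattern (the names here are hypothetical):

```python
class Spec:
    """Hypothetical stand-in for TensorSpec."""


table: dict[int, str] = {}

spec = Spec()
table[id(spec)] = "constraint for this exact spec object"

# Identity keys mean a distinct spec never collides, even if it held equal
# field values; lookups track the exact object attached to a graph node.
other = Spec()
assert id(spec) in table and id(other) not in table
```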
@@ -111,7 +138,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
         node --> view
              --> relu (or some other op that can be in-place)
         """
-        if node_source_info := self.get_source_info(node):
+        if node_source_info := self.get_relative_placement_source(node):
             node_spec = node.meta.get("spec")
             node_source_spec = node_source_info.source.meta.get("spec")
             return (
@@ -121,7 +148,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
                 and self.is_alias_of(node_source_info.source, other_node)
             )

-        if self.get_source_info(other_node) is not None:
+        if self.get_relative_placement_source(other_node) is not None:
             return self.is_alias_of(other_node, node)

         return node == other_node
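The alias check walks relative-placement links recursively until it bottoms out at node identity. A stripped-down model of that recursion (size checks elided, string names standing in for graph nodes):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Link:
    """Stand-in for RelativePlacementConstraint."""

    source: str
    offset: int = 0


# node --> view --> relu, all sharing the same buffer at offset 0.
links: dict[str, Link] = {
    "view": Link(source="node"),
    "relu": Link(source="view"),
}


def is_alias_of(a: str, b: str) -> bool:
    # Follow a's link toward its source (the real code also checks sizes).
    if (info := links.get(a)) is not None:
        return info.offset == 0 and is_alias_of(info.source, b)
    # Symmetric case: b may be the constrained node.
    if links.get(b) is not None:
        return is_alias_of(b, a)
    return a == b


assert is_alias_of("relu", "node") and not is_alias_of("relu", "other")
```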
@@ -132,14 +159,14 @@ def relative_loc_constraints_exist(self) -> bool:

     # Return true if the spec is marked as skipped
     def skipped_spec(self, spec: TensorSpec) -> bool:
-        return id(spec) in self._source_node
+        return id(spec) in self._relative_placement_constraint

     def is_memory_planned(
         self,
         node: torch.fx.Node,
     ) -> bool:
         """Return true if the node is either (1) a parameter, or (2) a placeholder."""
-        if (source_info := self.get_source_info(node)) is not None:
+        if (source_info := self.get_relative_placement_source(node)) is not None:
             # If node has relative placement constraints, then check the source.
             return self.is_memory_planned(source_info.source)
         # Check if any node is a param.
@@ -183,7 +210,7 @@ def resolve_relative_loc_constraints(self, spec: TensorSpec) -> None:

         assert isinstance(spec, TensorSpec)
         for dependent_node in self.unresolved_loc_constraints[spec_id]:
-            source_info = self.get_source_info(dependent_node)
+            source_info = self.get_relative_placement_source(dependent_node)
             assert source_info is not None
             dependent_spec = cast(TensorSpec, dependent_node.meta.get("spec"))
             dependent_spec.mem_id = spec.mem_id
@@ -202,19 +229,21 @@ def update_children_nodes(self, node: torch.fx.Node, update_lifetime: bool) -> None:
         children_nodes = self.unresolved_loc_constraints[id(node.meta.get("spec"))]
         self.unresolved_loc_constraints.pop(id(node.meta.get("spec")))

-        source_info = self.get_source_info(node)
+        source_info = self.get_relative_placement_source(node)
         assert source_info is not None

         for child_node in children_nodes:
-            child_info = self._source_node.pop(id(child_node.meta.get("spec")))
-            self.generate_location_constraint(
+            child_info = self._relative_placement_constraint.pop(
+                id(child_node.meta.get("spec"))
+            )
+            self.add_relative_placement_constraint(
                 source_info.source,
                 child_node,
                 offset=source_info.offset + child_info.offset,
                 update_lifetime=update_lifetime,
             )

-    def generate_location_constraint(
+    def add_relative_placement_constraint(
         self,
         source: torch.fx.Node,
         dependent: torch.fx.Node,
@@ -230,29 +259,26 @@ def generate_location_constraint(
         logging.debug(f"Adding constraint {dependent} = {source} + {offset=}")

         # Assert that both source and dependent node are tensors.
-        if (info := self.get_source_info(source)) is not None:
-            return self.generate_location_constraint(
-                info.source, dependent, offset + info.offset, update_lifetime
-            )
+        if (info := self.get_relative_placement_source(source)) is not None:
+            source = info.source
+            offset += info.offset

-        if (info := self.get_source_info(dependent)) is not None:
+        if (info := self.get_relative_placement_source(dependent)) is not None:
             # Dependent node can only be an alias (same size, offset = 0).
             assert self.is_alias_of(
                 info.source, dependent
             ), f"Multiple constraints for allocation of {dependent}. Previous constraint: {info} new constraint: {source=} {offset=}"
-            return self.generate_location_constraint(
-                source, info.source, offset, update_lifetime=update_lifetime
-            )
+            dependent = info.source

         # Add the dependent spec to skip list. Its memory offset will be computed
         # after the output tensor is allocated space.
-        source_info = SourceInfo(source=source, offset=offset)
-        self.set_source_info(dependent, source_info)
+        source_info = RelativePlacementConstraint(source=source, offset=offset)
+        self.set_relative_placement_constraint(dependent, source_info)
         # If update_lifetime is True, take a union of the lifetimes of the representative
         # and dependent tensors; this will become the new lifetime of the source tensor.
+        dependent_spec = dependent.meta.get("spec")
         if update_lifetime:
-            dependent_spec = dependent.meta.get("spec")
             source_spec = source.meta.get("spec")
             source.meta.get("spec").lifetime = [
                 min(source_spec.lifetime[0], dependent_spec.lifetime[0]),
@@ -261,6 +287,49 @@ def generate_location_constraint(

         self.update_children_nodes(dependent, update_lifetime)

+        abs_constraint = self.get_absolute_placement_constraint(dependent_spec)
+        if abs_constraint is None:
+            return
+
+        # Dependent node has an absolute placement constraint.
+        # If the offset is not 0, then we cannot add a relative placement constraint.
+        if not self.is_alias_of(dependent, source):
+            raise RuntimeError(
+                f"Cannot add relative placement constraint for {dependent} with non-zero offset {offset} when it has an absolute placement constraint {abs_constraint}"
+            )
+
+        # Add the absolute placement constraint to the source node.
+        self._absolute_placement_constraints.pop(id(dependent_spec))
+        self.add_absolute_placement_constraint(
+            source, abs_constraint.pinned_memory_id, abs_constraint.offset
+        )
+
+    def add_absolute_placement_constraint(
+        self, node: torch.fx.Node, pinned_memory_id: int, offset: Optional[int] = None
+    ) -> None:
310+ """Add a memory pinning constraint for `node` to `mem_id`."""
+        logging.debug(
+            f"Adding memory pinning constraint {node=} = {pinned_memory_id=} at {offset=}"
+        )
+        source_node: torch.fx.Node = node
+        if (info := self.get_relative_placement_source(node)) is not None:
+            assert self.is_alias_of(info.source, node)
+            logging.debug(
+                f"Setting {node} to {info.source} + {offset=}. Pinned to {pinned_memory_id=}"
+            )
+            source_node = info.source
+        self._absolute_placement_constraints[id(source_node.meta.get("spec"))] = (
+            AbsolutePlacementConstraint(
+                pinned_memory_id=pinned_memory_id, offset=offset
+            )
+        )
+
+    def get_absolute_placement_constraint(
+        self, spec: TensorSpec
+    ) -> Optional[AbsolutePlacementConstraint]:
330+ """Return true if `node` has an absolute placement constraint."""
+        return self._absolute_placement_constraints.get(id(spec), None)
+
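To show how the new absolute-placement API is meant to compose with relative constraints, here is a hedged sketch of the call pattern (assuming `mc` is a `MemConstraints` instance and `x`, `x_view` are graph nodes with specs attached; none of this code appears in the commit):

```python
import torch


def pin_view_pair(
    mc: "MemConstraints", x: torch.fx.Node, x_view: torch.fx.Node
) -> None:
    # Make x_view a pure alias of x (same buffer, offset 0).
    mc.add_relative_placement_constraint(x, x_view, offset=0)
    # Pin the alias to memory id 2; since x_view aliases x, the pin is
    # recorded against the source node x.
    mc.add_absolute_placement_constraint(x_view, pinned_memory_id=2)
    assert mc.get_absolute_placement_constraint(x.meta["spec"]) is not None
```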

 def get_relative_offsets_of_cat_tensors(
     cat_tensors: Sequence[torch.fx.Node],
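The body of this helper is not shown in the diff, but the arithmetic it needs is standard for a contiguous dim-0 concat: each input starts where the previous one ends. A worked sketch with made-up shapes:

```python
import math

# cat along dim 0 of float32 tensors with shapes (2, 4) and (3, 4).
shapes = [(2, 4), (3, 4)]
elem_bytes = 4  # float32

offsets, running = [], 0
for shape in shapes:
    offsets.append(running)
    running += math.prod(shape) * elem_bytes

assert offsets == [0, 32]  # second input starts 2*4*4 = 32 bytes in
```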
@@ -342,7 +411,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:

     def is_slice_view(self, node: torch.fx.Node) -> bool:
         """Return if `node` has constraints and is not an alias of another node."""
-        if (source_info := self.constraint.get_source_info(node)) is not None:
+        if (
+            source_info := self.constraint.get_relative_placement_source(node)
+        ) is not None:
             return not self.constraint.is_alias_of(source_info.source, node)
         return False

@@ -426,7 +497,9 @@ def is_removable_cat_op(
         return True

     # Currently the contiguity constraints are generated by the cat operator.
-    def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule):
+    def compute_cat_contiguity_constraints(
+        self, graph_module: torch.fx.GraphModule
+    ) -> None:
         for node in graph_module.graph.nodes:
             # Only compute relative constraints if the cat node can be replaced with
             # its nop version
@@ -448,7 +521,9 @@ def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule)
             # Get the relative offsets for each tensor to be concatenated.
             relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
             for arg, offset in zip(cat_tensors, relative_offsets):
-                self.constraint.generate_location_constraint(node, arg, offset=offset)
+                self.constraint.add_relative_placement_constraint(
+                    node, arg, offset=offset
+                )

             # Update the lifetimes of the args to that of the output tensor, so
             # that they don't get overwritten
@@ -474,7 +549,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
         for node in graph_module.graph.nodes:
             if node.op != "call_function" or node.target != memory.view:
                 continue
-            self.constraint.generate_location_constraint(node.args[0], node)
+            self.constraint.add_relative_placement_constraint(node.args[0], node)


 @register_cadence_pass(CadencePassAttribute(opt_level=2))
@@ -544,7 +619,7 @@ def removable_slice_or_select_op(
     # the input and output tensor.
     def compute_slice_and_select_loc_constraints(
         self, graph_module: torch.fx.GraphModule
-    ):
+    ) -> None:
         for node in graph_module.graph.nodes:
             # Only compute relative constraints if the slice node can be
             # replaced with its nop version
@@ -563,7 +638,7 @@ def compute_slice_and_select_loc_constraints(
             # And now generate the location constraint between the input and
             # output tensors of the slice node
             arg = node.args[0]
-            self.constraint.generate_location_constraint(
+            self.constraint.add_relative_placement_constraint(
                 arg,
                 node,
                 offset=offset,
@@ -607,12 +682,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         filtered_passes = [
             mcg_pass(self.mem_constraints)
             for mcg_pass in cast(
-                list[
-                    typing.Callable[
-                        [MemConstraints],
-                        typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
-                    ]
-                ],
+                list[ConstraintsGenPass],
                 # pyre-ignore[6]: Incompatible parameter type.
                 list(filter(pass_filter, constraint_gen_passes)),
             )
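The removed inline type is folded into a `ConstraintsGenPass` alias whose definition sits outside this hunk; judging from the deleted lines, it would read roughly as below, using the `TypeAlias` import added at the top of the file (the `PassResult` import path is an assumption):

```python
from typing import Callable, Optional, TypeAlias

import torch
from executorch.exir.pass_base import PassResult  # assumed import path

# Reconstructed from the removed cast: a constraints-generation pass takes
# MemConstraints and returns a callable over the graph module.
ConstraintsGenPass: TypeAlias = Callable[
    ["MemConstraints"],
    Callable[[torch.fx.GraphModule], Optional[PassResult]],
]
```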