Skip to content

Commit 5091b12

Browse files
committed
Update
[ghstack-poisoned]
1 parent 5f70823 commit 5091b12

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
116116
arg, self.buffer_limit
117117
)
118118

119+
op_available_layouts = features.supported_memory_layouts(
120+
VkStorageType.TEXTURE_3D
121+
)
122+
123+
can_use_texture = any(
124+
layout in op_available_layouts for layout in valid_texture_layouts
125+
)
126+
119127
# If there are no valid texture memory layouts, then buffer storage must be
120128
# supported by the operator implementation.
121-
if len(valid_texture_layouts) == 0:
129+
if not can_use_texture:
122130
if not can_use_buffers:
123131
return (
124132
False,
@@ -131,17 +139,8 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
131139
reason = "op requires buffers which is not supported by op impl"
132140
return compatible, reason
133141

134-
op_available_layouts = features.supported_memory_layouts(
135-
VkStorageType.TEXTURE_3D
136-
)
137142

138-
is_compatible = any(
139-
layout in op_available_layouts for layout in valid_texture_layouts
140-
)
141-
if not is_compatible:
142-
return False, "Required texutre memory layout not supported"
143-
144-
return is_compatible, "Op is compatible"
143+
return True, "Op is compatible"
145144

146145
def node_is_compatible(
147146
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -220,7 +219,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
220219

221220
def log_skip(self, node: torch.fx.Node, reason: str) -> None:
222221
if node.op == "call_function":
223-
logger.info(
222+
print(
224223
f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
225224
)
226225

@@ -231,6 +230,7 @@ def is_node_supported(
231230
return r
232231

233232
def _is_node_supported(self, node: torch.fx.Node) -> bool:
233+
print("is_node_supported")
234234
target = node.target
235235
if node.target == torch.ops.higher_order.auto_functionalized:
236236
first_arg = node.args[0]
@@ -340,6 +340,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
340340
# subgraphs containing the nodes with the tags
341341
partition_tags = {}
342342

343+
logger.setLevel(logging.INFO)
344+
print("partition")
345+
print("set level but no logging...")
346+
343347
texture_limits: utils.ImageExtents = self.options.get(
344348
"texture_limits", utils.DEFAULT_TEXTURE_LIMITS
345349
)

exir/program/_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,7 @@ def to_edge_transform_and_lower(
12571257
for name, partitioner_list in partitioner.items():
12581258
if i < len(partitioner_list):
12591259
method_to_partitioner[name] = partitioner_list[i]
1260+
print("to_backen")
12601261
edge_manager = edge_manager.to_backend(method_to_partitioner)
12611262

12621263
for name, program in edge_manager._edge_programs.items():

0 commit comments

Comments
 (0)