Skip to content

Commit ba4bb54

Browse files
committed
[ET-VK] Refine partitioner to account for storage type and memory layout
## Context There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory are: 1. Storage Type (buffer or texture) 2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.) Due to the differences between buffers and textures, and the differences between different memory layouts, an implementation for an operator may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting which results in optimal performance. These changes lay the foundation for the implementation of a memory metadata tagging graph transform, which will make sure that every tensor participating in an operator call has a valid/optimal (storage type, memory layout) setting, and insert transition operators to transfer input tensors to the correct memory settings when necessary. An additional change that is required arises from the fact that in Vulkan, there is a limit on texture and buffer sizes. Therefore, the partitioner needs to account for the storage types and memory layouts supported by the operator implementation, and check whether all tensors participating in a computation can be represented with some storage type, memory layout combination supported by the implementation. ## Changes Improvements to the operator registry: * Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator Improvements to the Partitioner: * Account for the storage types and memory layouts supported by an operator when deciding if a node should be partitioned * Improved logic for fusable ops (i.e. the permute/transpose before a mm which can be fused into linear) to check if the final target op is supported in Vulkan, and only partition those nodes if so. 
Otherwise, don't partition it so that it can be fused by another backend. Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/) [ghstack-poisoned]
1 parent 09cf982 commit ba4bb54

File tree

5 files changed

+416
-106
lines changed

5 files changed

+416
-106
lines changed

backends/vulkan/op_registry.py

Lines changed: 133 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,65 @@
88

99
import operator
1010

11-
from typing import Callable, Dict, List, Optional, Union
11+
from typing import Callable, Dict, Optional, Set, Union
1212

1313
import executorch.backends.vulkan.custom_ops_lib # noqa
1414

1515
import torch
1616

17-
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
17+
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
18+
VkMemoryLayout,
19+
VkStorageType,
20+
)
21+
22+
from executorch.backends.vulkan.utils import (
23+
all_memory_layouts,
24+
all_packed_dims,
25+
PackedDim,
26+
)
1827
from executorch.exir.dialects._ops import ops as exir_ops
1928

2029
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2130
from torch._subclasses.fake_tensor import FakeTensor
2231

32+
######################
33+
## OpFeatures class ##
34+
######################
35+
2336

2437
def allow_node(node: torch.fx.Node) -> bool:
2538
return True
2639

2740

2841
class TextureImplFeatures:
2942
__slots__ = [
30-
# Indicates if the compute shader is agnostic to the packed dimension
31-
"uses_packed_dim",
32-
# Indicates if the compute shader is agnostic to the texture axis mapping
43+
"valid_packed_dims",
3344
"uses_axis_map",
34-
# Specifies a specific set of memory layouts that the shader supports. If it is
35-
# and empty list, then the supported memory layouts can be inferred from the
36-
# `uses_packed_dim` and `uses_axis_map` flags.
37-
"supported_layouts",
3845
]
3946

4047
def __init__(
4148
self,
42-
uses_packed_dim: bool = False,
4349
uses_axis_map: bool = False,
44-
supported_layouts: Optional[List[VkMemoryLayout]] = None,
50+
valid_packed_dims: Optional[Set[PackedDim]] = None,
4551
):
46-
self.uses_packed_dim: bool = uses_packed_dim
4752
self.uses_axis_map: bool = uses_axis_map
48-
self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts
53+
self.valid_packed_dims = set()
54+
if valid_packed_dims is not None:
55+
self.valid_packed_dims = valid_packed_dims
56+
57+
def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
58+
layouts = set()
59+
60+
if PackedDim.WIDTH in self.valid_packed_dims:
61+
layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)
62+
63+
if PackedDim.HEIGHT in self.valid_packed_dims:
64+
layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)
65+
66+
if PackedDim.CHANNELS in self.valid_packed_dims:
67+
layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)
68+
69+
return layouts
4970

5071

5172
class OpFeatures:
@@ -58,6 +79,9 @@ class OpFeatures:
5879
# bool indicating if the operator has a resize function, which allows it to
5980
# support dynamic shape tensors.
6081
"resize_fn",
82+
# Optimal
83+
"optimal_storage",
84+
"optimal_layout",
6185
# bool indicating if the operator handles its own prepacking. If this is True,
6286
# then the insert_prepack_nodes pass will not insert prepack nodes for the args
6387
# of the op.
@@ -72,17 +96,64 @@ def __init__(
7296
texture_impl: Optional[TextureImplFeatures] = None,
7397
buffer_impl: bool = False,
7498
resize_fn: bool = False,
99+
optimal_storage: Optional[VkStorageType] = None,
100+
optimal_layout: Optional[VkMemoryLayout] = None,
75101
handles_own_prepacking: bool = False,
76102
check_node_fn: Optional[Callable] = None,
77103
):
78104
self.texture_impl: Optional[TextureImplFeatures] = texture_impl
79105
self.buffer_impl: bool = buffer_impl
80106
self.resize_fn: bool = resize_fn
107+
self.optimal_storage: Optional[VkStorageType] = optimal_storage
108+
self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
81109
self.handles_own_prepacking: bool = handles_own_prepacking
82110
self.check_node_fn: Callable = allow_node
83111
if check_node_fn is not None:
84112
self.check_node_fn = check_node_fn
85113

114+
def propose_storage_type(self) -> Optional[VkStorageType]:
115+
if self.optimal_storage is not None:
116+
return self.optimal_storage
117+
118+
if self.texture_impl is not None and not self.buffer_impl:
119+
return VkStorageType.TEXTURE_3D
120+
elif self.buffer_impl and self.texture_impl is None:
121+
return VkStorageType.BUFFER
122+
123+
return None
124+
125+
def supported_storage_types(self) -> Set[VkStorageType]:
126+
storage_types = set()
127+
if self.texture_impl is not None:
128+
storage_types.add(VkStorageType.TEXTURE_3D)
129+
if self.buffer_impl:
130+
storage_types.add(VkStorageType.BUFFER)
131+
132+
return storage_types
133+
134+
def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
135+
if self.optimal_layout is not None:
136+
return self.optimal_layout
137+
138+
if storage == VkStorageType.TEXTURE_3D:
139+
assert self.texture_impl is not None
140+
possible_layouts = self.texture_impl.valid_memory_layouts()
141+
if len(possible_layouts) == 1:
142+
return next(iter(possible_layouts))
143+
144+
return None
145+
146+
def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
147+
if storage == VkStorageType.TEXTURE_3D:
148+
assert self.texture_impl is not None
149+
return self.texture_impl.valid_memory_layouts()
150+
else:
151+
return all_memory_layouts
152+
153+
154+
#######################
155+
## Operator Registry ##
156+
#######################
86157

87158
OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]
88159

@@ -122,8 +193,8 @@ def update_features_impl(op: OpKey):
122193
)
123194
def register_ephemeral_op(features: OpFeatures):
124195
features.texture_impl = TextureImplFeatures(
125-
uses_packed_dim=True,
126196
uses_axis_map=True,
197+
valid_packed_dims=all_packed_dims,
127198
)
128199
features.buffer_impl = True
129200
features.resize_fn = True
@@ -143,8 +214,8 @@ def register_ephemeral_op(features: OpFeatures):
143214
)
144215
def register_binary_op(features: OpFeatures):
145216
features.texture_impl = TextureImplFeatures(
146-
uses_packed_dim=True,
147217
uses_axis_map=True,
218+
valid_packed_dims=all_packed_dims,
148219
)
149220
features.resize_fn = True
150221
return features
@@ -170,8 +241,8 @@ def register_binary_op(features: OpFeatures):
170241
)
171242
def register_unary_op(features: OpFeatures):
172243
features.texture_impl = TextureImplFeatures(
173-
uses_packed_dim=True,
174244
uses_axis_map=True,
245+
valid_packed_dims=all_packed_dims,
175246
)
176247
features.buffer_impl = True
177248
features.resize_fn = True
@@ -181,8 +252,8 @@ def register_unary_op(features: OpFeatures):
181252
@update_features(exir_ops.edge.aten._to_copy.default)
182253
def register_to_copy_op(features: OpFeatures):
183254
features.texture_impl = TextureImplFeatures(
184-
uses_packed_dim=True,
185255
uses_axis_map=True,
256+
valid_packed_dims=all_packed_dims,
186257
)
187258
features.resize_fn = True
188259

@@ -220,40 +291,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
220291
)
221292
def register_mm_op(features: OpFeatures):
222293
features.texture_impl = TextureImplFeatures(
223-
uses_packed_dim=False,
224294
uses_axis_map=True,
225-
supported_layouts=[
226-
VkMemoryLayout.TENSOR_WIDTH_PACKED,
227-
VkMemoryLayout.TENSOR_CHANNELS_PACKED,
228-
],
295+
valid_packed_dims={
296+
PackedDim.WIDTH,
297+
PackedDim.CHANNELS,
298+
},
229299
)
230300
features.buffer_impl = True
231301
features.resize_fn = True
302+
features.optimal_storage = VkStorageType.TEXTURE_3D
303+
features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
232304
features.handles_own_prepacking = True
233305
return features
234306

235307

236308
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
237309
def register_int8_mm_op(features: OpFeatures):
238310
features.texture_impl = TextureImplFeatures(
239-
uses_packed_dim=False,
240311
uses_axis_map=False,
241-
supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
312+
valid_packed_dims={PackedDim.WIDTH},
242313
)
243314
features.buffer_impl = True
244315
features.resize_fn = True
316+
features.optimal_storage = VkStorageType.TEXTURE_3D
317+
features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
245318
features.handles_own_prepacking = True
246319
return features
247320

248321

249322
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
250323
def register_int4_mm_op(features: OpFeatures):
251324
features.texture_impl = TextureImplFeatures(
252-
uses_packed_dim=False,
253325
uses_axis_map=False,
254-
supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
326+
valid_packed_dims={PackedDim.WIDTH},
255327
)
256328
features.resize_fn = True
329+
features.optimal_storage = VkStorageType.TEXTURE_3D
330+
features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
257331
features.handles_own_prepacking = True
258332
return features
259333

@@ -266,7 +340,7 @@ def register_int4_mm_op(features: OpFeatures):
266340
)
267341
def register_softmax_op(features: OpFeatures):
268342
features.texture_impl = TextureImplFeatures(
269-
uses_packed_dim=True,
343+
valid_packed_dims=all_packed_dims,
270344
)
271345
features.resize_fn = True
272346
return features
@@ -282,7 +356,7 @@ def register_softmax_op(features: OpFeatures):
282356
)
283357
def register_reduce_op(features: OpFeatures):
284358
features.texture_impl = TextureImplFeatures(
285-
uses_packed_dim=True,
359+
valid_packed_dims=all_packed_dims,
286360
)
287361
features.resize_fn = True
288362

@@ -309,7 +383,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
309383
)
310384
def register_2d_pool_op(features: OpFeatures):
311385
features.texture_impl = TextureImplFeatures(
312-
supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
386+
valid_packed_dims={PackedDim.CHANNELS},
313387
)
314388
features.resize_fn = True
315389
return features
@@ -323,27 +397,31 @@ def register_2d_pool_op(features: OpFeatures):
323397
)
324398
def register_convolution_op(features: OpFeatures):
325399
features.texture_impl = TextureImplFeatures(
326-
supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
400+
valid_packed_dims={PackedDim.CHANNELS},
327401
)
328402
features.resize_fn = True
403+
features.optimal_storage = VkStorageType.TEXTURE_3D
404+
features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
329405
features.handles_own_prepacking = True
330406
return features
331407

332408

333409
@update_features("llama::sdpa_with_kv_cache")
334410
def register_sdpa_op(features: OpFeatures):
335411
features.texture_impl = TextureImplFeatures(
336-
supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
412+
valid_packed_dims={PackedDim.WIDTH},
337413
)
338414
features.resize_fn = True
415+
features.optimal_storage = VkStorageType.TEXTURE_3D
416+
features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
339417
features.handles_own_prepacking = True
340418
return features
341419

342420

343421
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
344422
def register_rotary_emb_op(features: OpFeatures):
345423
features.texture_impl = TextureImplFeatures(
346-
supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
424+
valid_packed_dims={PackedDim.WIDTH},
347425
)
348426
features.resize_fn = True
349427
return features
@@ -352,7 +430,7 @@ def register_rotary_emb_op(features: OpFeatures):
352430
@update_features(exir_ops.edge.aten.view_copy.default)
353431
def register_view_op(features: OpFeatures):
354432
features.texture_impl = TextureImplFeatures(
355-
uses_packed_dim=True,
433+
valid_packed_dims=all_packed_dims,
356434
)
357435
features.resize_fn = True
358436
return features
@@ -393,7 +471,7 @@ def register_view_op(features: OpFeatures):
393471
)
394472
def register_ported_op(features: OpFeatures):
395473
features.texture_impl = TextureImplFeatures(
396-
supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
474+
valid_packed_dims={PackedDim.CHANNELS},
397475
)
398476
return features
399477

@@ -408,15 +486,24 @@ def register_ported_op(features: OpFeatures):
408486
)
409487
def register_ported_ops_with_prepacking(features: OpFeatures):
410488
features.texture_impl = TextureImplFeatures(
411-
supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
489+
valid_packed_dims={PackedDim.CHANNELS},
412490
)
413491
features.handles_own_prepacking = True
414492
return features
415493

416494

417-
##
418-
## Utility Functions
419-
##
495+
#######################
496+
## Utility functions ##
497+
#######################
498+
499+
500+
def has_impl(target: OpKey) -> bool:
501+
if not isinstance(target, str):
502+
if target not in vulkan_supported_ops:
503+
return target.name() in vulkan_supported_ops
504+
return target in vulkan_supported_ops
505+
else:
506+
return target in vulkan_supported_ops
420507

421508

422509
def get_op_features(target: OpKey) -> OpFeatures:
@@ -430,5 +517,13 @@ def get_op_features(target: OpKey) -> OpFeatures:
430517
return vulkan_supported_ops[target]
431518

432519

520+
def optimal_storage_type(target: OpKey) -> Optional[VkStorageType]:
521+
return get_op_features(target).optimal_storage
522+
523+
524+
def optimal_memory_layout(target: OpKey) -> Optional[VkMemoryLayout]:
525+
return get_op_features(target).optimal_layout
526+
527+
433528
def handles_own_prepacking(target: OpKey) -> bool:
434529
return get_op_features(target).handles_own_prepacking

backends/vulkan/partitioner/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ runtime.python_library(
1313
],
1414
deps = [
1515
"//executorch/backends/vulkan:op_registry",
16+
"//executorch/backends/vulkan:utils_lib",
1617
"//executorch/backends/vulkan:vulkan_preprocess",
1718
"//executorch/exir:delegate",
1819
"//executorch/exir:lib",

0 commit comments

Comments
 (0)