88
99import operator
1010
11- from typing import Callable , Dict , List , Optional , Union
11+ from typing import Callable , Dict , Optional , Set , Union
1212
1313import executorch .backends .vulkan .custom_ops_lib # noqa
1414
1515import torch
1616
17- from executorch .backends .vulkan .serialization .vulkan_graph_schema import VkMemoryLayout
17+ from executorch .backends .vulkan .serialization .vulkan_graph_schema import (
18+ VkMemoryLayout ,
19+ VkStorageType ,
20+ )
21+
22+ from executorch .backends .vulkan .utils import (
23+ all_memory_layouts ,
24+ all_packed_dims ,
25+ PackedDim ,
26+ )
1827from executorch .exir .dialects ._ops import ops as exir_ops
1928
2029from executorch .exir .dialects .edge ._ops import EdgeOpOverload
2130from torch ._subclasses .fake_tensor import FakeTensor
2231
32+ ######################
33+ ## OpFeatures class ##
34+ ######################
35+
2336
def allow_node(node: torch.fx.Node) -> bool:
    """
    Permissive node check that accepts every node.

    This is the fallback assigned to `OpFeatures.check_node_fn` when no custom
    check function is supplied. The `node` argument is not inspected.

    Returns:
        Always True.
    """
    return True
2639
2740
class TextureImplFeatures:
    """
    Describes what the texture-based implementation of an operator supports.
    """

    __slots__ = [
        # Set of PackedDim values that the texture implementation can handle.
        # The supported memory layouts are derived from this set.
        "valid_packed_dims",
        # Boolean flag stored as given by the caller.
        # NOTE(review): presumably tells whether the shader honors the texture
        # axis mapping - confirm against the shader implementations.
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        """
        Args:
            uses_axis_map: value stored in `self.uses_axis_map`.
            valid_packed_dims: packed dimensions supported by the texture
                implementation. None is treated as an empty set.
        """
        self.uses_axis_map: bool = uses_axis_map
        # Always store a private copy. Callers commonly pass a shared
        # module-level set (e.g. `all_packed_dims`); keeping a reference to it
        # would let a mutation through one instance leak into every other
        # instance and into the shared set itself.
        self.valid_packed_dims: Set[PackedDim] = (
            set(valid_packed_dims) if valid_packed_dims is not None else set()
        )

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """
        Derive the set of memory layouts supported by the texture implementation
        based on the valid packed dimensions.

        Returns:
            A newly built set containing the memory layout that corresponds to
            each packed dimension present in `self.valid_packed_dims`. The set is
            empty when no packed dimension is valid.
        """
        # Each packed dimension maps to exactly one packed memory layout.
        dim_to_layout = (
            (PackedDim.WIDTH, VkMemoryLayout.TENSOR_WIDTH_PACKED),
            (PackedDim.HEIGHT, VkMemoryLayout.TENSOR_HEIGHT_PACKED),
            (PackedDim.CHANNELS, VkMemoryLayout.TENSOR_CHANNELS_PACKED),
        )
        return {
            layout for dim, layout in dim_to_layout if dim in self.valid_packed_dims
        }
4974
5075
5176class OpFeatures :
@@ -58,6 +83,9 @@ class OpFeatures:
5883 # bool indicating if the operator has a resize function, which allows it to
5984 # support dynamic shape tensors.
6085 "resize_fn" ,
86+ # Optimal (preferred) storage type and memory layout for the operator, if any
87+ "optimal_storage" ,
88+ "optimal_layout" ,
6189 # bool indicating if the operator handles its own prepacking. If this is True,
6290 # then the insert_prepack_nodes pass will not insert prepack nodes for the args
6391 # of the op.
@@ -72,17 +100,90 @@ def __init__(
72100 texture_impl : Optional [TextureImplFeatures ] = None ,
73101 buffer_impl : bool = False ,
74102 resize_fn : bool = False ,
103+ optimal_storage : Optional [VkStorageType ] = None ,
104+ optimal_layout : Optional [VkMemoryLayout ] = None ,
75105 handles_own_prepacking : bool = False ,
76106 check_node_fn : Optional [Callable ] = None ,
77107 ):
78108 self .texture_impl : Optional [TextureImplFeatures ] = texture_impl
79109 self .buffer_impl : bool = buffer_impl
80110 self .resize_fn : bool = resize_fn
111+ self .optimal_storage : Optional [VkStorageType ] = optimal_storage
112+ self .optimal_layout : Optional [VkMemoryLayout ] = optimal_layout
81113 self .handles_own_prepacking : bool = handles_own_prepacking
82114 self .check_node_fn : Callable = allow_node
83115 if check_node_fn is not None :
84116 self .check_node_fn = check_node_fn
85117
def propose_storage_type(self) -> Optional[VkStorageType]:
    """
    Suggest the storage type to use for this operator.

    A suggestion is produced when either:
    1. The operator declares an optimal storage type, which is returned as is.
    2. Exactly one storage type (texture or buffer) is implemented.

    When both texture and buffer implementations exist and no optimal storage
    type is declared, None is returned to signal that there is no preference.
    None is also returned when neither implementation exists.
    """
    preferred = self.optimal_storage
    if preferred is not None:
        return preferred

    has_texture = self.texture_impl is not None
    has_buffer = self.buffer_impl

    if has_texture and not has_buffer:
        return VkStorageType.TEXTURE_3D
    if has_buffer and not has_texture:
        return VkStorageType.BUFFER

    # Either both storage types are available or none is: no proposal.
    return None
137+
def supported_storage_types(self) -> Set[VkStorageType]:
    """
    Collect every storage type this operator has an implementation for.

    TEXTURE_3D is included when a texture implementation is present, and BUFFER
    is included when the buffer implementation flag is set. The result is a
    newly built set and may be empty.
    """
    candidates = (
        (self.texture_impl is not None, VkStorageType.TEXTURE_3D),
        (self.buffer_impl, VkStorageType.BUFFER),
    )
    return {storage for available, storage in candidates if available}
149+
def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
    """
    Suggest the memory layout to use for this operator, given a storage type.

    A suggestion is produced when either:
    1. The operator declares an optimal memory layout, which is returned
       regardless of `storage`.
    2. `storage` is TEXTURE_3D and the texture implementation supports exactly
       one memory layout.

    In every other case None is returned, meaning the "best" memory layout for
    the operator is ambiguous.
    """
    if self.optimal_layout is not None:
        return self.optimal_layout

    # Only texture storage can narrow the choice down to a single layout.
    if storage != VkStorageType.TEXTURE_3D:
        return None

    assert self.texture_impl is not None
    candidates = self.texture_impl.valid_memory_layouts()
    if len(candidates) != 1:
        return None

    (only_layout,) = candidates
    return only_layout
171+
def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
    """
    List the memory layouts this operator supports for the given storage type.

    For TEXTURE_3D the layouts are derived from the texture implementation,
    which must be present. For any other storage type the module-level
    `all_memory_layouts` object is returned as is (not a copy).
    """
    if storage != VkStorageType.TEXTURE_3D:
        return all_memory_layouts

    assert self.texture_impl is not None
    return self.texture_impl.valid_memory_layouts()
182+
183+
184+ #######################
185+ ## Operator Registry ##
186+ #######################
86187
87188OpKey = Union [str , torch ._ops .OpOverload , EdgeOpOverload ]
88189
@@ -122,8 +223,8 @@ def update_features_impl(op: OpKey):
122223)
123224def register_ephemeral_op (features : OpFeatures ):
124225 features .texture_impl = TextureImplFeatures (
125- uses_packed_dim = True ,
126226 uses_axis_map = True ,
227+ valid_packed_dims = all_packed_dims ,
127228 )
128229 features .buffer_impl = True
129230 features .resize_fn = True
@@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures):
143244)
144245def register_binary_op (features : OpFeatures ):
145246 features .texture_impl = TextureImplFeatures (
146- uses_packed_dim = True ,
147247 uses_axis_map = True ,
248+ valid_packed_dims = all_packed_dims ,
148249 )
149250 features .resize_fn = True
150251 return features
@@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures):
170271)
171272def register_unary_op (features : OpFeatures ):
172273 features .texture_impl = TextureImplFeatures (
173- uses_packed_dim = True ,
174274 uses_axis_map = True ,
275+ valid_packed_dims = all_packed_dims ,
175276 )
176277 features .buffer_impl = True
177278 features .resize_fn = True
@@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures):
181282@update_features (exir_ops .edge .aten ._to_copy .default )
182283def register_to_copy_op (features : OpFeatures ):
183284 features .texture_impl = TextureImplFeatures (
184- uses_packed_dim = True ,
185285 uses_axis_map = True ,
286+ valid_packed_dims = all_packed_dims ,
186287 )
187288 features .resize_fn = True
188289
@@ -220,40 +321,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
220321)
221322def register_mm_op (features : OpFeatures ):
222323 features .texture_impl = TextureImplFeatures (
223- uses_packed_dim = False ,
224324 uses_axis_map = True ,
225- supported_layouts = [
226- VkMemoryLayout . TENSOR_WIDTH_PACKED ,
227- VkMemoryLayout . TENSOR_CHANNELS_PACKED ,
228- ] ,
325+ valid_packed_dims = {
326+ PackedDim . WIDTH ,
327+ PackedDim . CHANNELS ,
328+ } ,
229329 )
230330 features .buffer_impl = True
231331 features .resize_fn = True
332+ features .optimal_storage = VkStorageType .TEXTURE_3D
333+ features .optimal_layout = VkMemoryLayout .TENSOR_WIDTH_PACKED
232334 features .handles_own_prepacking = True
233335 return features
234336
235337
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    """
    Fill in the feature description for `aten._weight_int8pack_mm.default`.

    The texture implementation is limited to the width packed dimension and
    does not use the axis map. A buffer implementation and a resize function
    are declared, width packed 3D textures are marked as optimal, and the op
    handles its own prepacking.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )

    # Supported implementations
    features.texture_impl = texture_features
    features.buffer_impl = True

    # Preferred representation
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED

    # Remaining capabilities
    features.resize_fn = True
    features.handles_own_prepacking = True
    return features
247350
248351
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    """
    Fill in the feature description for `et_vk.linear_weight_int4.default`.

    The texture implementation is limited to the width packed dimension and
    does not use the axis map. No buffer implementation is declared here. A
    resize function is declared, width packed 3D textures are marked as
    optimal, and the op handles its own prepacking.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )
    features.texture_impl = texture_features

    # Preferred representation
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED

    # Remaining capabilities
    features.resize_fn = True
    features.handles_own_prepacking = True
    return features
259363
@@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures):
266370)
267371def register_softmax_op (features : OpFeatures ):
268372 features .texture_impl = TextureImplFeatures (
269- uses_packed_dim = True ,
373+ valid_packed_dims = all_packed_dims ,
270374 )
271375 features .resize_fn = True
272376 return features
@@ -282,7 +386,7 @@ def register_softmax_op(features: OpFeatures):
282386)
283387def register_reduce_op (features : OpFeatures ):
284388 features .texture_impl = TextureImplFeatures (
285- uses_packed_dim = True ,
389+ valid_packed_dims = all_packed_dims ,
286390 )
287391 features .resize_fn = True
288392
@@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
309413)
310414def register_2d_pool_op (features : OpFeatures ):
311415 features .texture_impl = TextureImplFeatures (
312- supported_layouts = [ VkMemoryLayout . TENSOR_CHANNELS_PACKED ] ,
416+ valid_packed_dims = { PackedDim . CHANNELS } ,
313417 )
314418 features .resize_fn = True
315419 return features
@@ -323,27 +427,31 @@ def register_2d_pool_op(features: OpFeatures):
323427)
324428def register_convolution_op (features : OpFeatures ):
325429 features .texture_impl = TextureImplFeatures (
326- supported_layouts = [ VkMemoryLayout . TENSOR_CHANNELS_PACKED ] ,
430+ valid_packed_dims = { PackedDim . CHANNELS } ,
327431 )
328432 features .resize_fn = True
433+ features .optimal_storage = VkStorageType .TEXTURE_3D
434+ features .optimal_layout = VkMemoryLayout .TENSOR_CHANNELS_PACKED
329435 features .handles_own_prepacking = True
330436 return features
331437
332438
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    """
    Fill in the feature description for the `llama::sdpa_with_kv_cache` op.

    The texture implementation is limited to the width packed dimension. A
    resize function is declared, width packed 3D textures are marked as
    optimal, and the op handles its own prepacking.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH})
    features.texture_impl = texture_features

    # Preferred representation
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED

    # Remaining capabilities
    features.resize_fn = True
    features.handles_own_prepacking = True
    return features
341449
342450
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    """
    Fill in the feature description for `et_vk.apply_rotary_emb.default`.

    The texture implementation is limited to the width packed dimension, and a
    resize function is declared.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH})
    features.texture_impl = texture_features
    features.resize_fn = True
    return features
@@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures):
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    """
    Fill in the feature description for `aten.view_copy.default`.

    The texture implementation accepts every packed dimension
    (`all_packed_dims`), and a resize function is declared.
    """
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)
    features.texture_impl = texture_features
    features.resize_fn = True
    return features
@@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures):
393501)
394502def register_ported_op (features : OpFeatures ):
395503 features .texture_impl = TextureImplFeatures (
396- supported_layouts = [ VkMemoryLayout . TENSOR_CHANNELS_PACKED ] ,
504+ valid_packed_dims = { PackedDim . CHANNELS } ,
397505 )
398506 return features
399507
@@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures):
408516)
409517def register_ported_ops_with_prepacking (features : OpFeatures ):
410518 features .texture_impl = TextureImplFeatures (
411- supported_layouts = [ VkMemoryLayout . TENSOR_CHANNELS_PACKED ] ,
519+ valid_packed_dims = { PackedDim . CHANNELS } ,
412520 )
413521 features .handles_own_prepacking = True
414522 return features
415523
416524
417- ##
418- ## Utility Functions
419- ##
525+ #######################
526+ ## Utility functions ##
527+ #######################
528+
529+
def has_impl(target: OpKey) -> bool:
    """
    Tell whether `target` has an entry in `vulkan_supported_ops`.

    The target itself is looked up first. When that fails and the target is not
    a string, the lookup is retried with the value of `target.name()`.
    """
    if target in vulkan_supported_ops:
        return True

    # A string key has no alternative spelling to fall back on.
    if isinstance(target, str):
        return False

    return target.name() in vulkan_supported_ops
420537
421538
422539def get_op_features (target : OpKey ) -> OpFeatures :
0 commit comments