88
99import operator
1010
11- from typing import Callable , Dict , List , Optional , Union
11+ from typing import Callable , Dict , Optional , Set , Union
1212
1313import executorch .backends .vulkan .custom_ops_lib # noqa
1414
1515import torch
1616
17- from executorch .backends .vulkan .serialization .vulkan_graph_schema import VkMemoryLayout
17+ from executorch .backends .vulkan .serialization .vulkan_graph_schema import (
18+ VkMemoryLayout ,
19+ VkStorageType ,
20+ )
21+
22+ from executorch .backends .vulkan .utils import (
23+ all_memory_layouts ,
24+ all_packed_dims ,
25+ PackedDim ,
26+ )
1827from executorch .exir .dialects ._ops import ops as exir_ops
1928
2029from executorch .exir .dialects .edge ._ops import EdgeOpOverload
2130from torch ._subclasses .fake_tensor import FakeTensor
2231
32+ ######################
33+ ## OpFeatures class ##
34+ ######################
35+
2336
def allow_node(node: torch.fx.Node) -> bool:
    """Default node check: accept every node unconditionally.

    Args:
        node: The graph node being considered. It is not inspected.

    Returns:
        Always ``True``.
    """
    return True
2639
2740
class TextureImplFeatures:
    """Describes what an operator's texture-backed implementation supports.

    Instances only carry the two attributes declared in ``__slots__``.
    """

    __slots__ = [
        # Set of PackedDim values the texture implementation accepts. Empty
        # when no packed dims were supplied at construction time.
        "valid_packed_dims",
        # Whether the compute shader is agnostic to the texture axis mapping.
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        """Create a feature description.

        Args:
            uses_axis_map: Stored as-is on the instance.
            valid_packed_dims: Packed dimensions supported by the texture
                implementation. ``None`` is treated as the empty set.
        """
        self.uses_axis_map: bool = uses_axis_map
        # Copy the incoming set instead of keeping a reference to it. Many
        # registrations pass the same module-level `all_packed_dims` object;
        # aliasing it would let a mutation made through one op's features
        # silently change the features of every other op (and the global).
        self.valid_packed_dims: Set[PackedDim] = (
            set(valid_packed_dims) if valid_packed_dims is not None else set()
        )

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """Translate ``valid_packed_dims`` into the matching memory layouts.

        Returns:
            A new set holding the ``VkMemoryLayout`` for each of WIDTH, HEIGHT
            and CHANNELS that is present in ``valid_packed_dims``. The set is
            empty when none of them is present.
        """
        layout_for_dim = (
            (PackedDim.WIDTH, VkMemoryLayout.TENSOR_WIDTH_PACKED),
            (PackedDim.HEIGHT, VkMemoryLayout.TENSOR_HEIGHT_PACKED),
            (PackedDim.CHANNELS, VkMemoryLayout.TENSOR_CHANNELS_PACKED),
        )
        return {
            layout
            for packed_dim, layout in layout_for_dim
            if packed_dim in self.valid_packed_dims
        }
4970
5071
5172class OpFeatures :
@@ -58,6 +79,9 @@ class OpFeatures:
5879 # bool indicating if the operator has a resize function, which allows it to
5980 # support dynamic shape tensors.
6081 "resize_fn" ,
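82+ # Preferred storage type and memory layout for the operator, if any. Each is
82+ # None when the operator has no preference.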
83+ "optimal_storage" ,
84+ "optimal_layout" ,
6185 # bool indicating if the operator handles its own prepacking. If this is True,
6286 # then the insert_prepack_nodes pass will not insert prepack nodes for the args
6387 # of the op.
@@ -72,17 +96,64 @@ def __init__(
7296 texture_impl : Optional [TextureImplFeatures ] = None ,
7397 buffer_impl : bool = False ,
7498 resize_fn : bool = False ,
99+ optimal_storage : Optional [VkStorageType ] = None ,
100+ optimal_layout : Optional [VkMemoryLayout ] = None ,
75101 handles_own_prepacking : bool = False ,
76102 check_node_fn : Optional [Callable ] = None ,
77103 ):
78104 self .texture_impl : Optional [TextureImplFeatures ] = texture_impl
79105 self .buffer_impl : bool = buffer_impl
80106 self .resize_fn : bool = resize_fn
107+ self .optimal_storage : Optional [VkStorageType ] = optimal_storage
108+ self .optimal_layout : Optional [VkMemoryLayout ] = optimal_layout
81109 self .handles_own_prepacking : bool = handles_own_prepacking
82110 self .check_node_fn : Callable = allow_node
83111 if check_node_fn is not None :
84112 self .check_node_fn = check_node_fn
85113
114+ def propose_storage_type (self ) -> Optional [VkStorageType ]:
115+ if self .optimal_storage is not None :
116+ return self .optimal_storage
117+
118+ if self .texture_impl is not None and not self .buffer_impl :
119+ return VkStorageType .TEXTURE_3D
120+ elif self .buffer_impl and self .texture_impl is None :
121+ return VkStorageType .BUFFER
122+
123+ return None
124+
125+ def supported_storage_types (self ) -> Set [VkStorageType ]:
126+ storage_types = set ()
127+ if self .texture_impl is not None :
128+ storage_types .add (VkStorageType .TEXTURE_3D )
129+ if self .buffer_impl :
130+ storage_types .add (VkStorageType .BUFFER )
131+
132+ return storage_types
133+
134+ def propose_memory_layout (self , storage : VkStorageType ) -> Optional [VkMemoryLayout ]:
135+ if self .optimal_layout is not None :
136+ return self .optimal_layout
137+
138+ if storage == VkStorageType .TEXTURE_3D :
139+ assert self .texture_impl is not None
140+ possible_layouts = self .texture_impl .valid_memory_layouts ()
141+ if len (possible_layouts ) == 1 :
142+ return next (iter (possible_layouts ))
143+
144+ return None
145+
146+ def supported_memory_layouts (self , storage : VkStorageType ) -> Set [VkMemoryLayout ]:
147+ if storage == VkStorageType .TEXTURE_3D :
148+ assert self .texture_impl is not None
149+ return self .texture_impl .valid_memory_layouts ()
150+ else :
151+ return all_memory_layouts
152+
153+
154+ #######################
155+ ## Operator Registry ##
156+ #######################
86157
87158OpKey = Union [str , torch ._ops .OpOverload , EdgeOpOverload ]
88159
@@ -122,8 +193,8 @@ def update_features_impl(op: OpKey):
122193)
123194def register_ephemeral_op (features : OpFeatures ):
124195 features .texture_impl = TextureImplFeatures (
125- uses_packed_dim = True ,
126196 uses_axis_map = True ,
197+ valid_packed_dims = all_packed_dims ,
127198 )
128199 features .buffer_impl = True
129200 features .resize_fn = True
@@ -143,8 +214,8 @@ def register_ephemeral_op(features: OpFeatures):
143214)
def register_binary_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Sets a texture implementation that uses the axis map and accepts every
    packed dim, and marks the op as having a resize function.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
@@ -170,8 +241,8 @@ def register_binary_op(features: OpFeatures):
170241)
171242def register_unary_op (features : OpFeatures ):
172243 features .texture_impl = TextureImplFeatures (
173- uses_packed_dim = True ,
174244 uses_axis_map = True ,
245+ valid_packed_dims = all_packed_dims ,
175246 )
176247 features .buffer_impl = True
177248 features .resize_fn = True
@@ -181,8 +252,8 @@ def register_unary_op(features: OpFeatures):
181252@update_features (exir_ops .edge .aten ._to_copy .default )
182253def register_to_copy_op (features : OpFeatures ):
183254 features .texture_impl = TextureImplFeatures (
184- uses_packed_dim = True ,
185255 uses_axis_map = True ,
256+ valid_packed_dims = all_packed_dims ,
186257 )
187258 features .resize_fn = True
188259
@@ -220,40 +291,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
220291)
def register_mm_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    The texture implementation uses the axis map and accepts width- or
    channels-packed tensors; a buffer implementation and resize function are
    also declared, and the op does its own prepacking.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH, PackedDim.CHANNELS},
        uses_axis_map=True,
    )
    features.texture_impl = texture_features
    features.buffer_impl = True
    features.resize_fn = True
    features.handles_own_prepacking = True
    # Preferred representation when both storage types are available.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    return features
234306
235307
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    """Populate `features` for `aten._weight_int8pack_mm.default`.

    The texture implementation does not use the axis map and only accepts
    width-packed tensors; a buffer implementation and resize function are
    also declared, and the op does its own prepacking.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )
    features.texture_impl = texture_features
    features.buffer_impl = True
    features.resize_fn = True
    features.handles_own_prepacking = True
    # Preferred representation when both storage types are available.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    return features
247320
248321
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    """Populate `features` for `et_vk.linear_weight_int4.default`.

    Texture-only: the implementation does not use the axis map and only
    accepts width-packed tensors. A resize function is declared and the op
    does its own prepacking.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )
    features.texture_impl = texture_features
    features.resize_fn = True
    features.handles_own_prepacking = True
    # Preferred representation for this op.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    return features
259333
@@ -266,7 +340,7 @@ def register_int4_mm_op(features: OpFeatures):
266340)
def register_softmax_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Sets a texture implementation accepting every packed dim (axis map left
    at its default) and marks the op as having a resize function.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
@@ -282,7 +356,7 @@ def register_softmax_op(features: OpFeatures):
282356)
283357def register_reduce_op (features : OpFeatures ):
284358 features .texture_impl = TextureImplFeatures (
285- uses_packed_dim = True ,
359+ valid_packed_dims = all_packed_dims ,
286360 )
287361 features .resize_fn = True
288362
@@ -309,7 +383,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
309383)
def register_2d_pool_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Sets a texture implementation that only accepts channels-packed tensors
    and marks the op as having a resize function.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.CHANNELS})
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
@@ -323,27 +397,31 @@ def register_2d_pool_op(features: OpFeatures):
323397)
def register_convolution_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Texture-only with channels-packed tensors. A resize function is declared,
    the op does its own prepacking, and TEXTURE_3D / channels-packed is
    recorded as the preferred representation.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.CHANNELS})
    features.texture_impl = texture_features
    features.resize_fn = True
    features.handles_own_prepacking = True
    # Preferred representation for this op.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    return features
331407
332408
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    """Populate `features` for the "llama::sdpa_with_kv_cache" op.

    Texture-only with width-packed tensors. A resize function is declared,
    the op does its own prepacking, and TEXTURE_3D / width-packed is recorded
    as the preferred representation.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH})
    features.texture_impl = texture_features
    features.resize_fn = True
    features.handles_own_prepacking = True
    # Preferred representation for this op.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    return features
341419
342420
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    """Populate `features` for `et_vk.apply_rotary_emb.default`.

    Sets a texture implementation that only accepts width-packed tensors and
    marks the op as having a resize function.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH})
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
@@ -352,7 +430,7 @@ def register_rotary_emb_op(features: OpFeatures):
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    """Populate `features` for `aten.view_copy.default`.

    Sets a texture implementation accepting every packed dim and marks the op
    as having a resize function.

    Returns:
        The same `features` object, mutated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
@@ -393,7 +471,7 @@ def register_view_op(features: OpFeatures):
393471)
def register_ported_op(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Only sets a texture implementation restricted to channels-packed tensors;
    no other field is changed.

    Returns:
        The same `features` object, mutated in place.
    """
    channels_only = {PackedDim.CHANNELS}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=channels_only)
    return features
399477
@@ -408,15 +486,24 @@ def register_ported_op(features: OpFeatures):
408486)
def register_ported_ops_with_prepacking(features: OpFeatures):
    """Populate `features` for the ops this function is registered against.

    Sets a texture implementation restricted to channels-packed tensors and
    flags the op as doing its own prepacking.

    Returns:
        The same `features` object, mutated in place.
    """
    channels_only = {PackedDim.CHANNELS}
    features.handles_own_prepacking = True
    features.texture_impl = TextureImplFeatures(valid_packed_dims=channels_only)
    return features
415493
416494
417- ##
418- ## Utility Functions
419- ##
495+ #######################
496+ ## Utility functions ##
497+ #######################
498+
499+
def has_impl(target: OpKey) -> bool:
    """Report whether `target` has an entry in `vulkan_supported_ops`.

    Args:
        target: Either an operator name string or an operator overload object.

    Returns:
        True when `target` itself is a key of `vulkan_supported_ops`. For
        non-string targets, `target.name()` is tried as a fallback key.
    """
    if target in vulkan_supported_ops:
        return True
    # A string has no alternative spelling to try.
    if isinstance(target, str):
        return False
    return target.name() in vulkan_supported_ops
420507
421508
422509def get_op_features (target : OpKey ) -> OpFeatures :
@@ -430,5 +517,13 @@ def get_op_features(target: OpKey) -> OpFeatures:
430517 return vulkan_supported_ops [target ]
431518
432519
def optimal_storage_type(target: OpKey) -> Optional[VkStorageType]:
    """Return the `optimal_storage` recorded for `target`.

    The features are looked up through `get_op_features`; the result is None
    when the op declares no preferred storage type.
    """
    op_features = get_op_features(target)
    return op_features.optimal_storage
522+
523+
def optimal_memory_layout(target: OpKey) -> Optional[VkMemoryLayout]:
    """Return the `optimal_layout` recorded for `target`.

    The features are looked up through `get_op_features`; the result is None
    when the op declares no preferred memory layout.
    """
    op_features = get_op_features(target)
    return op_features.optimal_layout
526+
527+
def handles_own_prepacking(target: OpKey) -> bool:
    """Return the `handles_own_prepacking` flag recorded for `target`.

    The features are looked up through `get_op_features`.
    """
    op_features = get_op_features(target)
    return op_features.handles_own_prepacking
0 commit comments