1616
1717import torch
1818
19- from executorch .backends .vulkan .serialization .vulkan_graph_schema import VkMemoryLayout
20-
2119from executorch .exir .dialects ._ops import ops as exir_ops
2220
2321from executorch .exir .dialects .edge ._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
4846 # Optional check function used during partitioning to determine if a node's
4947 # inputs are supported by the operator implementation.
5048 "are_node_inputs_supported_fn" ,
49+ # Optional function to determine valid representation sets for input and outputs
50+ # once a node's actual inputs are known.
51+ "pick_io_storage_fn" ,
5152 ]
5253
5354 def __init__ (
@@ -61,6 +62,7 @@ def __init__(
6162 supports_resize : bool = False ,
6263 supports_prepacking : bool = False ,
6364 are_node_inputs_supported_fn : Optional [Callable ] = allow_node ,
65+ pick_io_storage_fn : Optional [Callable ] = None ,
6466 ):
6567 self .inputs_storage : utils .TensorRepSetList = utils .TensorRepSetList (
6668 inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@ def __init__(
7779 self .supports_prepacking = supports_prepacking
7880
7981 self .are_node_inputs_supported_fn = are_node_inputs_supported_fn
82+ self .pick_io_storage_fn = pick_io_storage_fn
8083
def make_op_repsets(
    self,
    op_node: torch.fx.Node,
    texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
) -> utils.OpRepSets:
    """Build the OpRepSets describing valid input/output representations for
    ``op_node``.

    If a ``pick_io_storage_fn`` was registered for this op, defer to it so the
    storage can be chosen based on the node's actual inputs; otherwise fall
    back to the statically registered input/output repsets.
    """
    # No per-node picker registered: use the static repsets as-is.
    if self.pick_io_storage_fn is None:
        return utils.OpRepSets(
            self.inputs_storage, self.outputs_storage, op_node, texture_limits
        )

    picked_inputs, picked_outputs = self.pick_io_storage_fn(op_node)
    return utils.OpRepSets(
        utils.TensorRepSetList(picked_inputs),
        utils.TensorRepSetList(picked_outputs),
        op_node,
        texture_limits,
    )
8997
9098
9199#######################
@@ -410,28 +418,16 @@ def register_softmax_op():
410418)
411419def register_reduce_op ():
412420 def check_reduce_node (node : torch .fx .Node ) -> bool :
421+ # Only one argument implies that the reduction is over the entire tensor, which
422+ # is not supported yet.
423+ if len (node .args ) == 1 :
424+ return False
425+
413426 dim_list = node .args [1 ]
427+ # Only 1D and 2D reductions are supported at the moment.
414428 if isinstance (dim_list , list ) and len (dim_list ) > 2 :
415429 return False
416430
417- if isinstance (dim_list , list ) and len (dim_list ) == 2 :
418- # Try to get the memory layout for this node
419- try :
420- memory_layout = utils .get_node_memory_layout (node )
421-
422- # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
423- if (
424- memory_layout is not None
425- and memory_layout != VkMemoryLayout .DEFAULT_LAYOUT
426- ):
427- # For now only default layout is supported for 2D reduction.
428- # Because we can't determine if the input is NCHW or NHWC here,
429- # assume the reduction dimension is packed so we cannot support it.
430- return False
431- except (AssertionError , KeyError , AttributeError ):
432- # If we can't get memory layout information, we'll assume the dims aren't packed
433- pass
434-
435431 def try_find_keepdim_arg (node : torch .fx .Node ) -> bool :
436432 for arg in node .args :
437433 if isinstance (arg , bool ):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
446442
447443 return True
448444
def pick_io_storage_for_reduce(node: torch.fx.Node):
    """Choose input/output texture repsets for a reduce node.

    For a 2D reduction, the texture's packed dim must be a dim that is NOT
    being reduced over, so both input and output are restricted to the texture
    layout packed along the surviving dim. Any other reduction may use any
    texture layout.

    Returns a ``(inputs_storage, outputs_storage)`` pair.
    """
    # Map a WHCN packed dim index to the corresponding texture repset.
    storage_for_packed_dim = {
        0: utils.WIDTH_PACKED_TEXTURE,
        1: utils.HEIGHT_PACKED_TEXTURE,
        2: utils.CHANNELS_PACKED_TEXTURE,
    }

    ndim = node.args[0].meta["val"].ndim
    dim_list = node.args[1]

    # Non-2D reductions are free to use any texture layout.
    if not (isinstance(dim_list, list) and len(dim_list) == 2):
        return utils.ANY_TEXTURE, utils.ANY_TEXTURE

    # Remove the reduced dims (converted to WHCN order) from the candidate
    # packed dims; whatever remains is safe to pack along.
    candidate_packed_dims = {0, 1, 2}
    for reduce_dim_nchw in dim_list:
        candidate_packed_dims.discard(
            utils.nchw_dim_to_whcn_dim(reduce_dim_nchw, ndim)
        )

    packed_dim = candidate_packed_dims.pop()
    assert packed_dim in [0, 1, 2]

    storage = storage_for_packed_dim[packed_dim]
    return storage, storage
474+
449475 return OpFeatures (
450476 inputs_storage = utils .ANY_TEXTURE ,
451477 supports_resize = True ,
452478 are_node_inputs_supported_fn = check_reduce_node ,
479+ pick_io_storage_fn = pick_io_storage_for_reduce ,
453480 )
454481
455482
@@ -474,6 +501,23 @@ def register_2d_pool_op():
474501 ]
475502)
476503def register_convolution_op ():
def check_conv_node(node: torch.fx.Node) -> bool:
    """Partition-time input check for aten convolution nodes.

    Returns False for configurations not supported yet:
      * 2D convolution (4-D input) with batch size != 1
      * transposed 1D convolution (3-D input with the transposed flag set)
    All other configurations are accepted.
    """
    x = node.args[0]
    x_shape = x.meta["val"].size()
    # 4-D input implies 2D convolution; only single-batch inputs supported.
    if len(x_shape) == 4:
        # Reuse x_shape rather than re-reading node meta a second time.
        if x_shape[0] != 1:
            return False
    # 3-D input implies 1D convolution.
    if len(x_shape) == 3:
        # args[6] is the `transposed` flag in the aten.convolution schema
        # — NOTE(review): assumed schema position; confirm against op def.
        transposed = node.args[6]
        # Transposed 1D convolution is not supported yet.
        if transposed:
            return False

    return True
520+
477521 return OpFeatures (
478522 inputs_storage = [
479523 utils .CHANNELS_PACKED_TEXTURE , # input
@@ -490,6 +534,7 @@ def register_convolution_op():
490534 ],
491535 supports_resize = True ,
492536 supports_prepacking = True ,
537+ are_node_inputs_supported_fn = check_conv_node ,
493538 )
494539
495540
@@ -666,6 +711,7 @@ def register_ported_ops_with_prepacking():
666711 return OpFeatures (
667712 inputs_storage = utils .CHANNELS_PACKED_TEXTURE ,
668713 supports_prepacking = True ,
714+ supports_resize = True ,
669715 )
670716
671717
@@ -696,6 +742,7 @@ def register_ported_ops_with_prepacking_all_dims():
696742 return OpFeatures (
697743 inputs_storage = utils .ANY_TEXTURE ,
698744 supports_prepacking = True ,
745+ supports_resize = True ,
699746 )
700747
701748
0 commit comments