16
16
17
17
import torch
18
18
19
- from executorch .backends .vulkan .serialization .vulkan_graph_schema import VkMemoryLayout
20
-
21
19
from executorch .exir .dialects ._ops import ops as exir_ops
22
20
23
21
from executorch .exir .dialects .edge ._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
48
46
# Optional check function used during partitioning to determine if a node's
49
47
# inputs are supported by the operator implementation.
50
48
"are_node_inputs_supported_fn" ,
49
+ # Optional function to determine valid representation sets for input and outputs
50
+ # once a node's actual inputs are known.
51
+ "pick_io_storage_fn" ,
51
52
]
52
53
53
54
def __init__ (
@@ -61,6 +62,7 @@ def __init__(
61
62
supports_resize : bool = False ,
62
63
supports_prepacking : bool = False ,
63
64
are_node_inputs_supported_fn : Optional [Callable ] = allow_node ,
65
+ pick_io_storage_fn : Optional [Callable ] = None ,
64
66
):
65
67
self .inputs_storage : utils .TensorRepSetList = utils .TensorRepSetList (
66
68
inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@ def __init__(
77
79
self .supports_prepacking = supports_prepacking
78
80
79
81
self .are_node_inputs_supported_fn = are_node_inputs_supported_fn
82
+ self .pick_io_storage_fn = pick_io_storage_fn
80
83
81
84
def make_op_repsets(
    self,
    op_node: torch.fx.Node,
    texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
) -> utils.OpRepSets:
    """Build the OpRepSets for `op_node`.

    By default the op's statically registered input/output storage sets are
    used; if the op registered a `pick_io_storage_fn`, it is consulted with
    the concrete node so the representation sets can be narrowed based on
    the node's actual inputs.
    """
    if self.pick_io_storage_fn is None:
        in_sets = self.inputs_storage
        out_sets = self.outputs_storage
    else:
        picked_in, picked_out = self.pick_io_storage_fn(op_node)
        in_sets = utils.TensorRepSetList(picked_in)
        out_sets = utils.TensorRepSetList(picked_out)

    return utils.OpRepSets(in_sets, out_sets, op_node, texture_limits)
89
97
90
98
91
99
#######################
@@ -410,28 +418,16 @@ def register_softmax_op():
410
418
)
411
419
def register_reduce_op ():
412
420
def check_reduce_node (node : torch .fx .Node ) -> bool :
421
+ # Only one argument implies that the reduction is over the entire tensor, which
422
+ # is not supported yet.
423
+ if len (node .args ) == 1 :
424
+ return False
425
+
413
426
dim_list = node .args [1 ]
427
+ # Only 1D and 2D reductions are supported at the moment.
414
428
if isinstance (dim_list , list ) and len (dim_list ) > 2 :
415
429
return False
416
430
417
- if isinstance (dim_list , list ) and len (dim_list ) == 2 :
418
- # Try to get the memory layout for this node
419
- try :
420
- memory_layout = utils .get_node_memory_layout (node )
421
-
422
- # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
423
- if (
424
- memory_layout is not None
425
- and memory_layout != VkMemoryLayout .DEFAULT_LAYOUT
426
- ):
427
- # For now only default layout is supported for 2D reduction.
428
- # Because we can't determine if the input is NCHW or NHWC here,
429
- # assume the reduction dimension is packed so we cannot support it.
430
- return False
431
- except (AssertionError , KeyError , AttributeError ):
432
- # If we can't get memory layout information, we'll assume the dims aren't packed
433
- pass
434
-
435
431
def try_find_keepdim_arg (node : torch .fx .Node ) -> bool :
436
432
for arg in node .args :
437
433
if isinstance (arg , bool ):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
446
442
447
443
return True
448
444
445
def pick_io_storage_for_reduce(node: torch.fx.Node):
    """Select texture storage for a reduce op so that the packed dimension is
    never one of the reduced dimensions.

    For a 2D reduction, the two reduced NCHW dims are mapped into WHCN space
    and the single surviving dim of {W, H, C} becomes the packed dim; the
    matching packed-texture representation is returned for both inputs and
    outputs. Any other reduction falls back to ANY_TEXTURE.
    """
    input_tensor = node.args[0]
    ndim = input_tensor.meta["val"].ndim
    dim_list = node.args[1]

    if not (isinstance(dim_list, list) and len(dim_list) == 2):
        return utils.ANY_TEXTURE, utils.ANY_TEXTURE

    # Remove the two reduced dims (in WHCN space) from {W=0, H=1, C=2};
    # whatever remains is a safe packed dim.
    reduced_whcn = {utils.nchw_dim_to_whcn_dim(d, ndim) for d in dim_list}
    packed_dim = ({0, 1, 2} - reduced_whcn).pop()
    assert packed_dim in [0, 1, 2]

    storage_by_packed_dim = {
        0: utils.WIDTH_PACKED_TEXTURE,
        1: utils.HEIGHT_PACKED_TEXTURE,
        2: utils.CHANNELS_PACKED_TEXTURE,
    }
    storage = storage_by_packed_dim[packed_dim]
    return storage, storage
474
+
449
475
return OpFeatures (
450
476
inputs_storage = utils .ANY_TEXTURE ,
451
477
supports_resize = True ,
452
478
are_node_inputs_supported_fn = check_reduce_node ,
479
+ pick_io_storage_fn = pick_io_storage_for_reduce ,
453
480
)
454
481
455
482
@@ -474,6 +501,23 @@ def register_2d_pool_op():
474
501
]
475
502
)
476
503
def register_convolution_op ():
504
def check_conv_node(node: torch.fx.Node) -> bool:
    """Return True if this convolution node is supported by the backend.

    Rejects 2D convolutions (rank-4 input) with batch size != 1 and
    transposed 1D convolutions (rank-3 input with the transposed flag set).
    """
    x = node.args[0]
    x_shape = x.meta["val"].size()
    rank = len(x_shape)

    # Rank-4 input => 2D convolution; only unit batch is supported.
    if rank == 4 and x_shape[0] != 1:
        return False

    # Rank-3 input => 1D convolution; args[6] is the `transposed` flag and
    # transposed 1D convolution is not supported yet.
    if rank == 3 and node.args[6]:
        return False

    return True
520
+
477
521
return OpFeatures (
478
522
inputs_storage = [
479
523
utils .CHANNELS_PACKED_TEXTURE , # input
@@ -490,6 +534,7 @@ def register_convolution_op():
490
534
],
491
535
supports_resize = True ,
492
536
supports_prepacking = True ,
537
+ are_node_inputs_supported_fn = check_conv_node ,
493
538
)
494
539
495
540
@@ -666,6 +711,7 @@ def register_ported_ops_with_prepacking():
666
711
return OpFeatures (
667
712
inputs_storage = utils .CHANNELS_PACKED_TEXTURE ,
668
713
supports_prepacking = True ,
714
+ supports_resize = True ,
669
715
)
670
716
671
717
@@ -696,6 +742,7 @@ def register_ported_ops_with_prepacking_all_dims():
696
742
return OpFeatures (
697
743
inputs_storage = utils .ANY_TEXTURE ,
698
744
supports_prepacking = True ,
745
+ supports_resize = True ,
699
746
)
700
747
701
748
0 commit comments