@@ -420,6 +420,119 @@ def register_softmax_op():
420420 )
421421
422422
def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
    """
    Determine which dims of a reduction node's input are reduced over.

    Reads the optional dim argument (``node.args[1]``); when it is absent or
    None, the reduction is over every dim of the input. Returns either a
    single normalized (non-negative) dim index, or a list of them, via
    ``utils.normalize_dims``.
    """
    ndim = utils.ndim_of(node.args[0])
    assert ndim is not None

    # Fix: the original checked `len(node.args) >= 1`, which is always true
    # for a reduction node (args[0] is the input), so `node.args[1]` raised
    # IndexError whenever the dim argument was omitted. Only read args[1]
    # when it actually exists; otherwise fall through to the None handling.
    dims_reduced = node.args[1] if len(node.args) > 1 else None

    # If dim_list is None, the reduction is over all the dims of the tensor.
    if dims_reduced is None:
        dims_reduced = list(range(ndim))

    # Special case for reducing tensors with shape [1, N] - this is equivalent to
    # reducing the last dim.
    if utils.is_unsqueezed_vector(node) and ndim == 2:
        dims_reduced = 1

    # Collapse a single-element dim list to a plain int so callers can compare
    # against an int dim directly.
    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
        dims_reduced = dims_reduced[0]

    assert isinstance(dims_reduced, (int, list, tuple))
    return utils.normalize_dims(dims_reduced, ndim)
445+
def get_keepdim_setting(node: torch.fx.Node) -> bool:
    """Return the keepdim flag of a reduction node, defaulting to False."""
    # The keepdim argument is the only boolean positional arg of these ops,
    # so the first bool found in node.args is the setting. Assume False when
    # no bool argument is present.
    return next((arg for arg in node.args if isinstance(arg, bool)), False)
453+
454+
def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
    """
    Checks if a reduction node is supported by the Vulkan backend's reduce per row
    special case implementation, i.e. a 1D reduction over the last dim of the
    input tensor.
    """
    ndim = utils.ndim_of(node.args[0])
    assert ndim is not None
    # get_dims_reduced returns a plain int (not a list) for single-dim
    # reductions, so this equality only holds for a 1D last-dim reduction.
    last_dim = ndim - 1
    return get_dims_reduced(node) == last_dim
466+
def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
    """
    Checks if a reduction node is supported by the general reduce
    implementation.
    """
    dims_reduced = get_dims_reduced(node)

    # Only 1D and 2D reductions are supported at the moment.
    reduces_many_dims = (
        isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2
    )
    if reduces_many_dims:
        return False

    # keepdim = False is not supported yet for the general implementation.
    keepdim = get_keepdim_setting(node)
    return not (isinstance(keepdim, bool) and not keepdim)
480+
def is_reduce_node_supported(node: torch.fx.Node) -> bool:
    """Return True if either reduce implementation can handle this node."""
    # Reductions producing a 0-dim output are unsupported at the moment.
    if utils.ndim_of(node) == 0:
        return False

    if is_reduce_node_supported_by_per_row_impl(node):
        return True
    return is_reduce_node_supported_by_general_impl(node)
490+
def pick_storage_for_reduce(node: torch.fx.Node):
    """
    Pick the input/output storage settings for a reduction node.

    Returns a (inputs_storage, outputs_storage) pair. General-impl-supported
    nodes may use any texture storage; 1D last-dim reductions additionally
    support contiguous buffers (the per-row special case). For 2D reductions,
    the packed dim must not be one of the reduced dims, so storage is
    restricted to the texture layout packed along the remaining WHCN dim.
    """
    inputs_storage = utils.NO_STORAGE
    outputs_storage = utils.NO_STORAGE

    ndim = utils.ndim_of(node.args[0])
    # Fix: guard the optional dim argument. The original unconditionally read
    # node.args[1], which raises IndexError for whole-tensor reductions where
    # the dim arg is omitted (the same case get_dims_reduced handles by
    # substituting None).
    dim_list = node.args[1] if len(node.args) > 1 else None

    if is_reduce_node_supported_by_general_impl(node):
        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
        outputs_storage = inputs_storage

    # For 1D reductions of the last dim, a special reduce per row case is implemented
    # for buffer backed tensors.
    if is_reduce_node_supported_by_per_row_impl(node):
        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
        outputs_storage = inputs_storage
        return inputs_storage, outputs_storage

    # For 2D reductions, the packed dimension cannot be one of the reduced dims.
    # NOTE(review): dim_list is used un-normalized here (negative dims are
    # passed straight to nchw_dim_to_whcn_dim), unlike get_dims_reduced which
    # normalizes first — confirm nchw_dim_to_whcn_dim accepts negative dims.
    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
        # pyre-ignore[6]
        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
        # pyre-ignore[6]
        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)

        # Whichever of the W/H/C dims survives both discards becomes the
        # packed dim.
        possible_packed_dims = {0, 1, 2}
        possible_packed_dims.discard(reduce_dim1_whcn)
        possible_packed_dims.discard(reduce_dim2_whcn)

        packed_dim = possible_packed_dims.pop()
        assert packed_dim in [0, 1, 2]

        if packed_dim == 0:
            inputs_storage = utils.WIDTH_PACKED_TEXTURE
            outputs_storage = utils.WIDTH_PACKED_TEXTURE
        elif packed_dim == 1:
            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
        else:
            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
            outputs_storage = utils.CHANNELS_PACKED_TEXTURE

    return inputs_storage, outputs_storage
535+
423536@update_features (
424537 [
425538 exir_ops .edge .aten .mean .dim ,
@@ -429,66 +542,12 @@ def register_softmax_op():
429542 ]
430543)
431544def register_reduce_op ():
432- def check_reduce_node (node : torch .fx .Node ) -> bool :
433- # Only one argument implies that the reduction is over the entire tensor, which
434- # is not supported yet.
435- if len (node .args ) == 1 :
436- return False
437-
438- dim_list = node .args [1 ]
439- # Only 1D and 2D reductions are supported at the moment.
440- if isinstance (dim_list , list ) and len (dim_list ) > 2 :
441- return False
442-
443- def try_find_keepdim_arg (node : torch .fx .Node ) -> bool :
444- for arg in node .args :
445- if isinstance (arg , bool ):
446- return arg
447-
448- # Assume false by default
449- return False
450-
451- keepdim = try_find_keepdim_arg (node )
452- if isinstance (keepdim , bool ) and not keepdim :
453- return False
454-
455- return True
456-
457- def pick_io_storage_for_reduce (node : torch .fx .Node ):
458- inputs_storage = utils .ANY_TEXTURE
459- outputs_storage = utils .ANY_TEXTURE
460-
461- input_tensor = node .args [0 ]
462- ndim = input_tensor .meta ["val" ].ndim
463- dim_list = node .args [1 ]
464- if isinstance (dim_list , list ) and len (dim_list ) == 2 :
465- reduce_dim1_whcn = utils .nchw_dim_to_whcn_dim (dim_list [0 ], ndim )
466- reduce_dim2_whcn = utils .nchw_dim_to_whcn_dim (dim_list [1 ], ndim )
467-
468- possible_packed_dims = {0 , 1 , 2 }
469- possible_packed_dims .discard (reduce_dim1_whcn )
470- possible_packed_dims .discard (reduce_dim2_whcn )
471-
472- packed_dim = possible_packed_dims .pop ()
473- assert packed_dim in [0 , 1 , 2 ]
474-
475- if packed_dim == 0 :
476- inputs_storage = utils .WIDTH_PACKED_TEXTURE
477- outputs_storage = utils .WIDTH_PACKED_TEXTURE
478- elif packed_dim == 1 :
479- inputs_storage = utils .HEIGHT_PACKED_TEXTURE
480- outputs_storage = utils .HEIGHT_PACKED_TEXTURE
481- else :
482- inputs_storage = utils .CHANNELS_PACKED_TEXTURE
483- outputs_storage = utils .CHANNELS_PACKED_TEXTURE
484-
485- return inputs_storage , outputs_storage
486545
487546 return OpFeatures (
488547 inputs_storage = utils .ANY_TEXTURE ,
489548 supports_resize = True ,
490- are_node_inputs_supported_fn = check_reduce_node ,
491- pick_io_storage_fn = pick_io_storage_for_reduce ,
549+ are_node_inputs_supported_fn = is_reduce_node_supported ,
550+ pick_io_storage_fn = pick_storage_for_reduce ,
492551 )
493552
494553
@@ -515,6 +574,7 @@ def register_2d_pool_op():
515574def register_convolution_op ():
516575 def check_conv_node (node : torch .fx .Node ) -> bool :
517576 x = node .args [0 ]
577+ assert isinstance (x , torch .fx .Node )
518578 x_shape = x .meta ["val" ].size ()
519579 # 4-D input implies 2D convolution
520580 if len (x_shape ) == 4 :
0 commit comments