@@ -420,6 +420,114 @@ def register_softmax_op():
 )
 
 
+def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
+    ndim = utils.ndim_of(node.args[0])
+    dims_reduced = None
+    if len(node.args) > 1:
+        dims_reduced = node.args[1]
+
+    # If dim_list is None, return a list containing all the dims of the tensor
+    if dims_reduced is None:
+        dims_reduced = list(range(ndim))
+
+    # Special case for reducing tensors with shape [1, N] - this is equivalent to
+    # reducing the last dim.
+    if utils.is_unsqueezed_vector(node) and ndim == 2:
+        dims_reduced = 1
+
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
+        dims_reduced = dims_reduced[0]
+
+    return utils.normalize_dims(dims_reduced, ndim)
+
+
+def get_keepdim_setting(node: torch.fx.Node) -> bool:
+    for arg in node.args:
+        if isinstance(arg, bool):
+            return arg
+
+    # Assume false by default
+    return False
+
+
+def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
+    """
+    Checks if a reduction node is supported by the Vulkan backend's reduce per row
+    special case implementation.
+    """
+    input_ndim = utils.ndim_of(node.args[0])
+    dims_reduced = get_dims_reduced(node)
+
+    return dims_reduced == input_ndim - 1
+
+
+def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
+    dims_reduced = get_dims_reduced(node)
+    # Only 1D and 2D reductions are supported at the moment.
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2:
+        return False
+
+    keepdim = get_keepdim_setting(node)
+    # keepdim = False is not yet supported by the general implementation
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
+
+    return True
+
+
+def is_reduce_node_supported(node: torch.fx.Node) -> bool:
+    # 0-dim output unsupported at the moment
+    if utils.ndim_of(node) == 0:
+        return False
+
+    return is_reduce_node_supported_by_per_row_impl(
+        node
+    ) or is_reduce_node_supported_by_general_impl(node)
+
+
+def pick_storage_for_reduce(node: torch.fx.Node):
+    inputs_storage = utils.NO_STORAGE
+    outputs_storage = utils.NO_STORAGE
+
+    ndim = utils.ndim_of(node.args[0])
+    dim_list = node.args[1] if len(node.args) > 1 else None
+
+    if is_reduce_node_supported_by_general_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
+        outputs_storage = inputs_storage
+
+    # For 1D reductions of the last dim, a special reduce per row case is implemented
+    # for buffer backed tensors.
+    if is_reduce_node_supported_by_per_row_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
+        outputs_storage = inputs_storage
+        return inputs_storage, outputs_storage
+
+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
+    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
+        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+        possible_packed_dims = {0, 1, 2}
+        possible_packed_dims.discard(reduce_dim1_whcn)
+        possible_packed_dims.discard(reduce_dim2_whcn)
+
+        packed_dim = possible_packed_dims.pop()
+        assert packed_dim in [0, 1, 2]
+
+        if packed_dim == 0:
+            inputs_storage = utils.WIDTH_PACKED_TEXTURE
+            outputs_storage = utils.WIDTH_PACKED_TEXTURE
+        elif packed_dim == 1:
+            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+        else:
+            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+            outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+    return inputs_storage, outputs_storage
+
+
 @update_features(
     [
         exir_ops.edge.aten.mean.dim,
@@ -429,66 +537,12 @@ def register_softmax_op():
     ]
 )
 def register_reduce_op():
-    def check_reduce_node(node: torch.fx.Node) -> bool:
-        # Only one argument implies that the reduction is over the entire tensor, which
-        # is not supported yet.
-        if len(node.args) == 1:
-            return False
-
-        dim_list = node.args[1]
-        # Only 1D and 2D reductions are supported at the moment.
-        if isinstance(dim_list, list) and len(dim_list) > 2:
-            return False
-
-        def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
-            for arg in node.args:
-                if isinstance(arg, bool):
-                    return arg
-
-            # Assume false by default
-            return False
-
-        keepdim = try_find_keepdim_arg(node)
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
-
-        return True
-
-    def pick_io_storage_for_reduce(node: torch.fx.Node):
-        inputs_storage = utils.ANY_TEXTURE
-        outputs_storage = utils.ANY_TEXTURE
-
-        input_tensor = node.args[0]
-        ndim = input_tensor.meta["val"].ndim
-        dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
-            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
-
-            possible_packed_dims = {0, 1, 2}
-            possible_packed_dims.discard(reduce_dim1_whcn)
-            possible_packed_dims.discard(reduce_dim2_whcn)
-
-            packed_dim = possible_packed_dims.pop()
-            assert packed_dim in [0, 1, 2]
-
-            if packed_dim == 0:
-                inputs_storage = utils.WIDTH_PACKED_TEXTURE
-                outputs_storage = utils.WIDTH_PACKED_TEXTURE
-            elif packed_dim == 1:
-                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
-                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
-            else:
-                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
-                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
-
-        return inputs_storage, outputs_storage
 
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_reduce_node,
-        pick_io_storage_fn=pick_io_storage_for_reduce,
+        are_node_inputs_supported_fn=is_reduce_node_supported,
+        pick_io_storage_fn=pick_storage_for_reduce,
     )
 
 
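For context, a minimal standalone sketch of the packed-dim selection that pick_storage_for_reduce performs for 2D reductions, assuming the usual NCHW-to-WHCN index mapping (the innermost NCHW dim maps to WHCN 0 / width). to_whcn and pick_packed_dim here are hypothetical stand-ins for illustration, not the backend's utils helpers:

def to_whcn(nchw_dim: int, ndim: int) -> int:
    # Assumed mapping: WHCN counts from the innermost dim, so the last NCHW dim
    # is width (0), the second-to-last is height (1), and so on.
    return ndim - 1 - (nchw_dim % ndim)

def pick_packed_dim(dim_list, ndim):
    # The packed texture dim must survive the reduction, so discard the WHCN
    # indices of both reduced dims and keep whatever remains out of {W, H, C}.
    remaining = {0, 1, 2}
    for d in dim_list:
        remaining.discard(to_whcn(d, ndim))
    return remaining.pop()

# Reducing dims [2, 3] of a 4D NCHW tensor removes height (1) and width (0),
# leaving channels (2), so such a node would request CHANNELS_PACKED_TEXTURE.
print(pick_packed_dim([2, 3], 4))  # 2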