@@ -420,75 +420,131 @@ def register_softmax_op():
420420 )
421421
422422
423- @update_features (
424- [
425- exir_ops .edge .aten .mean .dim ,
426- exir_ops .edge .aten .sum .dim_IntList ,
427- exir_ops .edge .aten .amax .default ,
428- exir_ops .edge .aten .amin .default ,
429- ]
430- )
431- def register_reduce_op ():
432- def check_reduce_node (node : torch .fx .Node ) -> bool :
433- # Only one argument implies that the reduction is over the entire tensor, which
434- # is not supported yet.
435- if len (node .args ) == 1 :
436- return False
423+ def get_dims_reduced (node : torch .fx .Node ) -> Union [int , List [int ]]:
424+ ndim = utils .ndim_of (node .args [0 ])
425+ assert ndim is not None
426+ dims_reduced = None
427+ if len (node .args ) >= 2 :
428+ dims_reduced = node .args [1 ]
437429
438- dim_list = node .args [1 ]
439- # Only 1D and 2D reductions are supported at the moment.
440- if isinstance (dim_list , list ) and len (dim_list ) > 2 :
441- return False
 430+ # If dims_reduced is None, reduce over all dims of the tensor
431+ if dims_reduced is None :
432+ dims_reduced = list (range (ndim ))
442433
443- def try_find_keepdim_arg ( node : torch . fx . Node ) -> bool :
444- for arg in node . args :
445- if isinstance ( arg , bool ) :
446- return arg
434+ # Special case for reducing tensors with shape [1, N] - this is equivalent to
435+ # reducing the last dim.
436+ if utils . is_unsqueezed_vector ( node ) and ndim == 2 :
437+ dims_reduced = 1
447438
448- # Assume false by default
449- return False
439+ if isinstance ( dims_reduced , ( list , tuple )) and len ( dims_reduced ) == 1 :
440+ dims_reduced = dims_reduced [ 0 ]
450441
451- keepdim = try_find_keepdim_arg (node )
452- if isinstance (keepdim , bool ) and not keepdim :
453- return False
442+ assert isinstance (dims_reduced , (int , list , tuple ))
443+ return utils .normalize_dims (dims_reduced , ndim )
454444
455- return True
456445
457- def pick_io_storage_for_reduce (node : torch .fx .Node ):
458- inputs_storage = utils .ANY_TEXTURE
459- outputs_storage = utils .ANY_TEXTURE
460-
461- input_tensor = node .args [0 ]
462- ndim = input_tensor .meta ["val" ].ndim
463- dim_list = node .args [1 ]
464- if isinstance (dim_list , list ) and len (dim_list ) == 2 :
465- reduce_dim1_whcn = utils .nchw_dim_to_whcn_dim (dim_list [0 ], ndim )
466- reduce_dim2_whcn = utils .nchw_dim_to_whcn_dim (dim_list [1 ], ndim )
467-
468- possible_packed_dims = {0 , 1 , 2 }
469- possible_packed_dims .discard (reduce_dim1_whcn )
470- possible_packed_dims .discard (reduce_dim2_whcn )
471-
472- packed_dim = possible_packed_dims .pop ()
473- assert packed_dim in [0 , 1 , 2 ]
474-
475- if packed_dim == 0 :
476- inputs_storage = utils .WIDTH_PACKED_TEXTURE
477- outputs_storage = utils .WIDTH_PACKED_TEXTURE
478- elif packed_dim == 1 :
479- inputs_storage = utils .HEIGHT_PACKED_TEXTURE
480- outputs_storage = utils .HEIGHT_PACKED_TEXTURE
481- else :
482- inputs_storage = utils .CHANNELS_PACKED_TEXTURE
483- outputs_storage = utils .CHANNELS_PACKED_TEXTURE
446+ def get_keepdim_setting (node : torch .fx .Node ) -> bool :
447+ for arg in node .args :
448+ if isinstance (arg , bool ):
449+ return arg
450+
451+ # Assume false by default
452+ return False
453+
454+
455+ def is_reduce_node_supported_by_per_row_impl (node : torch .fx .Node ) -> bool :
456+ """
457+ Checks if a reduction node is supported by the Vulkan backend's reduce per row
458+ special case implementation.
459+ """
460+ input_ndim = utils .ndim_of (node .args [0 ])
461+ assert input_ndim is not None
462+ dims_reduced = get_dims_reduced (node )
463+
464+ return dims_reduced == input_ndim - 1
465+
466+
467+ def is_reduce_node_supported_by_general_impl (node : torch .fx .Node ) -> bool :
468+ dims_reduced = get_dims_reduced (node )
469+ # Only 1D and 2D reductions are supported at the moment.
470+ if isinstance (dims_reduced , (list , tuple )) and len (dims_reduced ) > 2 :
471+ return False
472+
473+ keepdim = get_keepdim_setting (node )
474+ # keepdim = False is not supported yet for general implementation
475+ if isinstance (keepdim , bool ) and not keepdim :
476+ return False
477+
478+ return True
479+
480+
481+ def is_reduce_node_supported (node : torch .fx .Node ) -> bool :
482+ return is_reduce_node_supported_by_per_row_impl (
483+ node
484+ ) or is_reduce_node_supported_by_general_impl (node )
485+
484486
487+ def pick_storage_for_reduce (node : torch .fx .Node ):
488+ inputs_storage = utils .NO_STORAGE
489+ outputs_storage = utils .NO_STORAGE
490+
491+ ndim = utils .ndim_of (node .args [0 ])
492+ dim_list = get_dims_reduced (node )
493+
494+ if is_reduce_node_supported_by_general_impl (node ):
495+ inputs_storage = inputs_storage .make_union (utils .ANY_TEXTURE )
496+ outputs_storage = inputs_storage
497+
498+ # For 1D reductions of the last dim, a special reduce per row case is implemented
499+ # for buffer backed tensors.
500+ if is_reduce_node_supported_by_per_row_impl (node ):
501+ inputs_storage = inputs_storage .make_union (utils .CONTIGUOUS_BUFFER )
502+ outputs_storage = inputs_storage
485503 return inputs_storage , outputs_storage
486504
505+ # For 2D reductions, the packed dimension cannot be one of the reduced dims
506+ if isinstance (dim_list , (list , tuple )) and len (dim_list ) == 2 :
507+ # pyre-ignore[6]
508+ reduce_dim1_whcn = utils .nchw_dim_to_whcn_dim (dim_list [0 ], ndim )
509+ # pyre-ignore[6]
510+ reduce_dim2_whcn = utils .nchw_dim_to_whcn_dim (dim_list [1 ], ndim )
511+
512+ possible_packed_dims = {0 , 1 , 2 }
513+ possible_packed_dims .discard (reduce_dim1_whcn )
514+ possible_packed_dims .discard (reduce_dim2_whcn )
515+
516+ packed_dim = possible_packed_dims .pop ()
517+ assert packed_dim in [0 , 1 , 2 ]
518+
519+ if packed_dim == 0 :
520+ inputs_storage = utils .WIDTH_PACKED_TEXTURE
521+ outputs_storage = utils .WIDTH_PACKED_TEXTURE
522+ elif packed_dim == 1 :
523+ inputs_storage = utils .HEIGHT_PACKED_TEXTURE
524+ outputs_storage = utils .HEIGHT_PACKED_TEXTURE
525+ else :
526+ inputs_storage = utils .CHANNELS_PACKED_TEXTURE
527+ outputs_storage = utils .CHANNELS_PACKED_TEXTURE
528+
529+ return inputs_storage , outputs_storage
530+
531+
532+ @update_features (
533+ [
534+ exir_ops .edge .aten .mean .dim ,
535+ exir_ops .edge .aten .sum .dim_IntList ,
536+ exir_ops .edge .aten .amax .default ,
537+ exir_ops .edge .aten .amin .default ,
538+ exir_ops .edge .aten .argmax .default ,
539+ exir_ops .edge .aten .argmin .default ,
540+ ]
541+ )
542+ def register_reduce_op ():
487543 return OpFeatures (
488544 inputs_storage = utils .ANY_TEXTURE ,
489545 supports_resize = True ,
490- are_node_inputs_supported_fn = check_reduce_node ,
491- pick_io_storage_fn = pick_io_storage_for_reduce ,
546+ are_node_inputs_supported_fn = is_reduce_node_supported ,
547+ pick_io_storage_fn = pick_storage_for_reduce ,
492548 )
493549
494550
@@ -515,6 +571,7 @@ def register_2d_pool_op():
515571def register_convolution_op ():
516572 def check_conv_node (node : torch .fx .Node ) -> bool :
517573 x = node .args [0 ]
574+ assert isinstance (x , torch .fx .Node )
518575 x_shape = x .meta ["val" ].size ()
519576 # 4-D input implies 2D convolution
520577 if len (x_shape ) == 4 :
0 commit comments