File tree Expand file tree Collapse file tree 2 files changed +12
-24
lines changed Expand file tree Collapse file tree 2 files changed +12
-24
lines changed Original file line number Diff line number Diff line change @@ -384,26 +384,14 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
384384 memory_layout = utils .get_node_memory_layout (node )
385385
386386 # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
387- if memory_layout is not None :
388- for dim in dim_list :
389- # For WIDTH_PACKED layout, dimension 3 (W) is packed
390- # For HEIGHT_PACKED layout, dimension 2 (H) is packed
391- # For CHANNELS_PACKED layout, dimension 1 (C) is packed
392- if (
393- (
394- memory_layout == VkMemoryLayout .TENSOR_WIDTH_PACKED
395- and dim == 3
396- )
397- or (
398- memory_layout == VkMemoryLayout .TENSOR_HEIGHT_PACKED
399- and dim == 2
400- )
401- or (
402- memory_layout == VkMemoryLayout .TENSOR_CHANNELS_PACKED
403- and dim == 1
404- )
405- ):
406- return False
387+ if (
388+ memory_layout is not None
389+ and memory_layout != VkMemoryLayout .DEFAULT_LAYOUT
390+ ):
391+ # For now only default layout is supported for 2D reduction.
392+ # Because we can't determine if the input is NCHW or NHWC here,
393+ # assume the reduction dimension is packed so we cannot support it.
394+ return False
407395 except (AssertionError , KeyError , AttributeError ):
408396 # If we can't get memory layout information, we'll assume the dims aren't packed
409397 pass
Original file line number Diff line number Diff line change @@ -37,19 +37,19 @@ void resize_reduce2d_node(
3737 ComputeGraph* graph,
3838 const std::vector<ArgGroup>& args,
3939 const std::vector<ValueRef>& resize_args) {
40- vTensorPtr out = graph-> get_tensor ( args[ 0 ] .refs [ 0 ] );
41- vTensorPtr in = graph-> get_tensor ( args[ 1 ] .refs [ 0 ] );
40+ const ValueRef out = args. at ( 0 ) .refs . at ( 0 );
41+ const ValueRef in = args. at ( 1 ) .refs . at ( 0 );
4242
4343 // Extract the dimensions to reduce over
4444 const std::vector<int64_t > dims_list =
4545 graph->extract_int_or_symint_list (resize_args.at (0 ));
4646 int32_t reduce_dim1_nchw = dims_list[0 ];
4747 int32_t reduce_dim2_nchw = dims_list[1 ];
4848
49- std::vector<int64_t > new_sizes = in-> sizes ( );
49+ std::vector<int64_t > new_sizes = graph-> sizes_of (in );
5050 new_sizes.at (normalize (reduce_dim1_nchw, new_sizes.size ())) = 1 ;
5151 new_sizes.at (normalize (reduce_dim2_nchw, new_sizes.size ())) = 1 ;
52- out ->virtual_resize (new_sizes);
52+ graph ->virtual_resize (out, new_sizes);
5353}
5454
5555utils::uvec3 reduce_global_wg_size (
You can’t perform that action at this time.
0 commit comments