Skip to content

Commit 5797608

Browse files
authored
[ET-VK] Fix Build Errors in Vulkan Backend (#13170)
This change fixes build issues that arose from the addition of 2D reduction to the Vulkan backend: 112a09f ```get_tensor()``` was moved to be a protected member of ComputeGraph a few minutes before the above commit got merged. This change also slightly modifies op_registry.py to have a more conservative approach of allowing 2D reduction to be delegated. cc @SS-JIA @manuelcandales @cbilgin
1 parent b114f9c commit 5797608

File tree

2 files changed

+12
-24
lines changed

2 files changed

+12
-24
lines changed

backends/vulkan/op_registry.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -384,26 +384,14 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
384384
memory_layout = utils.get_node_memory_layout(node)
385385

386386
# If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
387-
if memory_layout is not None:
388-
for dim in dim_list:
389-
# For WIDTH_PACKED layout, dimension 3 (W) is packed
390-
# For HEIGHT_PACKED layout, dimension 2 (H) is packed
391-
# For CHANNELS_PACKED layout, dimension 1 (C) is packed
392-
if (
393-
(
394-
memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
395-
and dim == 3
396-
)
397-
or (
398-
memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
399-
and dim == 2
400-
)
401-
or (
402-
memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
403-
and dim == 1
404-
)
405-
):
406-
return False
387+
if (
388+
memory_layout is not None
389+
and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
390+
):
391+
# For now only default layout is supported for 2D reduction.
392+
# Because we can't determine if the input is NCHW or NHWC here,
393+
# assume the reduction dimension is packed so we cannot support it.
394+
return False
407395
except (AssertionError, KeyError, AttributeError):
408396
# If we can't get memory layout information, we'll assume the dims aren't packed
409397
pass

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,19 @@ void resize_reduce2d_node(
3737
ComputeGraph* graph,
3838
const std::vector<ArgGroup>& args,
3939
const std::vector<ValueRef>& resize_args) {
40-
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
41-
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
40+
const ValueRef out = args.at(0).refs.at(0);
41+
const ValueRef in = args.at(1).refs.at(0);
4242

4343
// Extract the dimensions to reduce over
4444
const std::vector<int64_t> dims_list =
4545
graph->extract_int_or_symint_list(resize_args.at(0));
4646
int32_t reduce_dim1_nchw = dims_list[0];
4747
int32_t reduce_dim2_nchw = dims_list[1];
4848

49-
std::vector<int64_t> new_sizes = in->sizes();
49+
std::vector<int64_t> new_sizes = graph->sizes_of(in);
5050
new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
5151
new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
52-
out->virtual_resize(new_sizes);
52+
graph->virtual_resize(out, new_sizes);
5353
}
5454

5555
utils::uvec3 reduce_global_wg_size(

0 commit comments

Comments
 (0)