Skip to content

Commit 01502f6

Browse files

Commit message: Update 2D reduction to only be enabled for non-packed dims

Commit: 01502f6 (1 parent: 22c69b9)

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

backends/vulkan/op_registry.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
all_memory_layouts,
2424
all_packed_dims,
2525
PackedDim,
26+
get_node_memory_layout,
2627
)
2728
from executorch.exir.dialects._ops import ops as exir_ops
2829

@@ -535,6 +536,27 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
535536
if isinstance(dim_list, list) and len(dim_list) > 2:
536537
return False
537538

539+
if isinstance(dim_list, list) and len(dim_list) == 2:
540+
# Try to get the memory layout for this node
541+
try:
542+
memory_layout = get_node_memory_layout(node)
543+
544+
# If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
545+
if memory_layout is not None:
546+
for dim in dim_list:
547+
# For WIDTH_PACKED layout, dimension 3 (W) is packed
548+
if memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED and dim == 3:
549+
return False
550+
# For HEIGHT_PACKED layout, dimension 2 (H) is packed
551+
elif memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED and dim == 2:
552+
return False
553+
# For CHANNELS_PACKED layout, dimension 1 (C) is packed
554+
elif memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED and dim == 1:
555+
return False
556+
except (AssertionError, KeyError, AttributeError):
557+
# If we can't get memory layout information, we'll assume the dims aren't packed
558+
pass
559+
538560
keepdim = node.args[2]
539561
if isinstance(keepdim, bool) and not keepdim:
540562
return False

backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ int tid_to_smi(const ivec2 tid) {
5858
// with the accumulator.
5959
#define POSTPROCESS(accum) ${POSTPROCESS}
6060

61-
void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
61+
void reduce_2d_non_packed_dim(const ivec2 tid, ivec3 scan_pos) {
6262
// shared memory index of this thread
6363
const int smi = tid_to_smi(tid);
6464

@@ -124,5 +124,5 @@ void main() {
124124
return;
125125
}
126126

127-
reduce_2d(tid, scan_pos);
127+
reduce_2d_non_packed_dim(tid, scan_pos);
128128
}

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ void add_reduce2d_node(
179179
reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
180180
reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);
181181

182+
// Check that none of the reduction dims are packed
183+
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim1);
184+
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim2);
185+
VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim1);
186+
VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim2);
187+
182188
// Check that the concat dim is not one of the reduction dims
183189
if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
184190
VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
@@ -232,13 +238,12 @@ void add_reduce2d_node(
232238
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
233239
return add_reduce_node( \
234240
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
235-
}
236-
if (dims_list.size() == 2) { \
241+
} \
242+
if (dims_list.size() == 2) { \
237243
return add_reduce2d_node( \
238244
graph, args[0], args[1], args[out_arg_idx], #op_name); \
239-
}
240-
VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
241245
} \
246+
VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
242247
}
243248

244249
DEFINE_REDUCE_FN(sum, 4)

0 commit comments

Comments (0)