
Commit f891c4b

Fix style in Reduction.cpp and op_registry.py
1 parent df13204 commit f891c4b

File tree

backends/vulkan/op_registry.py
backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

2 files changed: +32 −16

backends/vulkan/op_registry.py

Lines changed: 12 additions & 3 deletions
@@ -385,13 +385,22 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         if memory_layout is not None:
             for dim in dim_list:
                 # For WIDTH_PACKED layout, dimension 3 (W) is packed
-                if memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED and dim == 3:
+                if (
+                    memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
+                    and dim == 3
+                ):
                     return False
                 # For HEIGHT_PACKED layout, dimension 2 (H) is packed
-                elif memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED and dim == 2:
+                elif (
+                    memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
+                    and dim == 2
+                ):
                     return False
                 # For CHANNELS_PACKED layout, dimension 1 (C) is packed
-                elif memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED and dim == 1:
+                elif (
+                    memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
+                    and dim == 1
+                ):
                     return False
     except (AssertionError, KeyError, AttributeError):
         # If we can't get memory layout information, we'll assume the dims aren't packed
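For context, check_reduce_node rejects any reduction over the dimension that the tensor's texture layout packs (W = 3, H = 2, C = 1 in NCHW order). A minimal standalone Python sketch of that rule follows; the _PACKED_NCHW_DIM table and dims_are_unpacked helper are hypothetical names for illustration, not ExecuTorch API, and VkMemoryLayout is stubbed out as a local Enum:

from enum import Enum, auto

class VkMemoryLayout(Enum):
    TENSOR_WIDTH_PACKED = auto()
    TENSOR_HEIGHT_PACKED = auto()
    TENSOR_CHANNELS_PACKED = auto()

# NCHW index packed by each layout: W = 3, H = 2, C = 1.
_PACKED_NCHW_DIM = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED: 3,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED: 2,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED: 1,
}

def dims_are_unpacked(memory_layout, dim_list):
    # The reduction is supported only if none of the reduced
    # dims is the packed dim of the chosen layout.
    packed = _PACKED_NCHW_DIM.get(memory_layout)
    return all(dim != packed for dim in dim_list)

For example, dims_are_unpacked(VkMemoryLayout.TENSOR_WIDTH_PACKED, [1, 2]) is True, while reducing over [2, 3] under the same layout is rejected because dim 3 is packed.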

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 20 additions & 13 deletions
@@ -40,7 +40,8 @@ void resize_reduce2d_node(
   vTensorPtr in = graph->get_tensor(args[1].refs[0]);

   // Extract the dimensions to reduce over
-  const std::vector<int64_t> dims_list = graph->extract_int_or_symint_list(resize_args.at(0));
+  const std::vector<int64_t> dims_list =
+      graph->extract_int_or_symint_list(resize_args.at(0));
   int32_t reduce_dim1_nchw = dims_list[0];
   int32_t reduce_dim2_nchw = dims_list[1];

@@ -161,24 +162,25 @@ void add_reduce2d_node(
     const ValueRef dims_ref,
     const ValueRef out,
     const std::string& op_name) {
-
   VK_CHECK_COND(
       !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
       "Vulkan reduction only supports texture storage");

   const int64_t ndim = graph.dim_of(in);
-
+
   // Extract the two dimensions to reduce over
-  const std::vector<int64_t> dims_list = graph.extract_int_or_symint_list(dims_ref);
-  VK_CHECK_COND(dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");
-
+  const std::vector<int64_t> dims_list =
+      graph.extract_int_or_symint_list(dims_ref);
+  VK_CHECK_COND(
+      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");
+
   int32_t reduce_dim1 = normalize(dims_list[0], ndim);
   int32_t reduce_dim2 = normalize(dims_list[1], ndim);
-
+
   // Convert to WHCN format
   reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
   reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);
-
+
   // Check that none of the reduction dims are packed
   VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim1);
   VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim2);
@@ -193,7 +195,7 @@ void add_reduce2d_node(
     VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
   }

-  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
+  std::string kernel_name = op_name + "2d";  // Add "2d" suffix
   kernel_name.reserve(kShaderNameReserve);
   add_dtype_suffix(kernel_name, graph.dtype_of(out));

@@ -206,8 +208,10 @@ void add_reduce2d_node(
     }
   }

-  const ValueRef reduce_dim1_whcn_ref = graph.get_or_add_value_for_int(reduce_dim1);
-  const ValueRef reduce_dim2_whcn_ref = graph.get_or_add_value_for_int(reduce_dim2);
+  const ValueRef reduce_dim1_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim1);
+  const ValueRef reduce_dim2_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim2);
   const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);

   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -224,15 +228,18 @@ void add_reduce2d_node(
       // Specialization Constants
       {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
       // Resize Args
-      {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
+      {dims_ref,
+       reduce_dim1_whcn_ref,
+       reduce_dim2_whcn_ref,
+       group_dim_whcn_ref},
       // Resizing Logic
       resize_reduce2d_node));
 }

 #define DEFINE_REDUCE_FN(op_name, out_arg_idx)                            \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
     const std::vector<int64_t> dims_list =                                \
-        graph.extract_int_or_symint_list(args[1]); \
+        graph.extract_int_or_symint_list(args[1]);                        \
     if (dims_list.size() == 1) {                                          \
       const int64_t dim_val = dims_list.at(0);                            \
       const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);   \
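The C++ path above first normalizes possibly negative dims against ndim and then remaps them from NCHW order to WHCN order before the packed-dim checks. A rough Python sketch of that index arithmetic, assuming nchw_dim_to_whcn_dim simply reverses the index order (the convention that puts W at WHCN position 0; this is an assumption about the helper, not code from this commit):

def normalize(dim, ndim):
    # Wrap negative dims, e.g. -1 -> ndim - 1.
    return dim % ndim

def nchw_dim_to_whcn_dim(nchw_dim, ndim):
    # Reversing the index maps NCHW positions onto WHCN: for
    # ndim == 4, W (3) -> 0, H (2) -> 1, C (1) -> 2, N (0) -> 3.
    return ndim - 1 - nchw_dim

# Reducing over dims (-2, -1) of a 4-D tensor:
dims = [normalize(d, 4) for d in (-2, -1)]           # [2, 3] (H, W)
whcn = [nchw_dim_to_whcn_dim(d, 4) for d in dims]    # [1, 0] in WHCN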
