@@ -40,7 +40,8 @@ void resize_reduce2d_node(
4040 vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
4141
4242 // Extract the dimensions to reduce over
43- const std::vector<int64_t > dims_list = graph->extract_int_or_symint_list (resize_args.at (0 ));
43+ const std::vector<int64_t > dims_list =
44+ graph->extract_int_or_symint_list (resize_args.at (0 ));
4445 int32_t reduce_dim1_nchw = dims_list[0 ];
4546 int32_t reduce_dim2_nchw = dims_list[1 ];
4647
@@ -161,24 +162,25 @@ void add_reduce2d_node(
161162 const ValueRef dims_ref,
162163 const ValueRef out,
163164 const std::string& op_name) {
164-
165165 VK_CHECK_COND (
166166 !graph.is_buffer_storage (in) && !graph.is_buffer_storage (out),
167167 " Vulkan reduction only supports texture storage" );
168168
169169 const int64_t ndim = graph.dim_of (in);
170-
170+
171171 // Extract the two dimensions to reduce over
172- const std::vector<int64_t > dims_list = graph.extract_int_or_symint_list (dims_ref);
173- VK_CHECK_COND (dims_list.size () == 2 , " reduce2d requires exactly 2 dimensions" );
174-
172+ const std::vector<int64_t > dims_list =
173+ graph.extract_int_or_symint_list (dims_ref);
174+ VK_CHECK_COND (
175+ dims_list.size () == 2 , " reduce2d requires exactly 2 dimensions" );
176+
175177 int32_t reduce_dim1 = normalize (dims_list[0 ], ndim);
176178 int32_t reduce_dim2 = normalize (dims_list[1 ], ndim);
177-
179+
178180 // Convert to WHCN format
179181 reduce_dim1 = nchw_dim_to_whcn_dim (reduce_dim1, ndim);
180182 reduce_dim2 = nchw_dim_to_whcn_dim (reduce_dim2, ndim);
181-
183+
182184 // Check that none of the reduction dims are packed
183185 VK_CHECK_COND (graph.packed_dim_of (in) != reduce_dim1);
184186 VK_CHECK_COND (graph.packed_dim_of (in) != reduce_dim2);
@@ -193,7 +195,7 @@ void add_reduce2d_node(
193195 VK_CHECK_COND (graph.concat_dim_of (out) != reduce_dim2);
194196 }
195197
196- std::string kernel_name = op_name + " 2d" ; // Add "2d" suffix
198+ std::string kernel_name = op_name + " 2d" ; // Add "2d" suffix
197199 kernel_name.reserve (kShaderNameReserve );
198200 add_dtype_suffix (kernel_name, graph.dtype_of (out));
199201
@@ -206,8 +208,10 @@ void add_reduce2d_node(
206208 }
207209 }
208210
209- const ValueRef reduce_dim1_whcn_ref = graph.get_or_add_value_for_int (reduce_dim1);
210- const ValueRef reduce_dim2_whcn_ref = graph.get_or_add_value_for_int (reduce_dim2);
211+ const ValueRef reduce_dim1_whcn_ref =
212+ graph.get_or_add_value_for_int (reduce_dim1);
213+ const ValueRef reduce_dim2_whcn_ref =
214+ graph.get_or_add_value_for_int (reduce_dim2);
211215 const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int (group_dim);
212216
213217 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
@@ -224,15 +228,18 @@ void add_reduce2d_node(
224228 // Specialization Constants
225229 {graph.packed_dim_of (out), reduce_dim1, reduce_dim2, group_dim},
226230 // Resize Args
227- {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
231+ {dims_ref,
232+ reduce_dim1_whcn_ref,
233+ reduce_dim2_whcn_ref,
234+ group_dim_whcn_ref},
228235 // Resizing Logic
229236 resize_reduce2d_node));
230237}
231238
232239#define DEFINE_REDUCE_FN (op_name, out_arg_idx ) \
233240 void op_name (ComputeGraph& graph, const std::vector<ValueRef>& args) { \
234241 const std::vector<int64_t > dims_list = \
235- graph.extract_int_or_symint_list (args[1 ]); \
242+ graph.extract_int_or_symint_list (args[1 ]); \
236243 if (dims_list.size () == 1 ) { \
237244 const int64_t dim_val = dims_list.at (0 ); \
238245 const ValueRef dim_ref = graph.get_or_add_value_for_int (dim_val); \
0 commit comments