@@ -32,6 +32,24 @@ void resize_reduce_node(
3232 out->virtual_resize (new_sizes);
3333}
3434
35+ void resize_reduce2d_node (
36+ ComputeGraph* graph,
37+ const std::vector<ArgGroup>& args,
38+ const std::vector<ValueRef>& resize_args) {
39+ vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
40+ vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
41+
42+ // Extract the dimensions to reduce over
43+ const std::vector<int64_t > dims_list = graph->extract_int_or_symint_list (resize_args.at (0 ));
44+ int32_t reduce_dim1_nchw = dims_list[0 ];
45+ int32_t reduce_dim2_nchw = dims_list[1 ];
46+
47+ std::vector<int64_t > new_sizes = in->sizes ();
48+ new_sizes.at (normalize (reduce_dim1_nchw, new_sizes.size ())) = 1 ;
49+ new_sizes.at (normalize (reduce_dim2_nchw, new_sizes.size ())) = 1 ;
50+ out->virtual_resize (new_sizes);
51+ }
52+
3553utils::uvec3 reduce_global_wg_size (
3654 ComputeGraph* graph,
3755 const vkapi::ShaderInfo& shader,
@@ -137,15 +155,89 @@ void add_reduce_node(
137155 resize_reduce_node));
138156}
139157
158+ void add_reduce2d_node (
159+ ComputeGraph& graph,
160+ const ValueRef in,
161+ const ValueRef dims_ref,
162+ const ValueRef out,
163+ const std::string& op_name) {
164+
165+ VK_CHECK_COND (
166+ !graph.is_buffer_storage (in) && !graph.is_buffer_storage (out),
167+ " Vulkan reduction only supports texture storage" );
168+
169+ const int64_t ndim = graph.dim_of (in);
170+
171+ // Extract the two dimensions to reduce over
172+ const std::vector<int64_t > dims_list = graph.extract_int_or_symint_list (dims_ref);
173+ VK_CHECK_COND (dims_list.size () == 2 , " reduce2d requires exactly 2 dimensions" );
174+
175+ int32_t reduce_dim1 = normalize (dims_list[0 ], ndim);
176+ int32_t reduce_dim2 = normalize (dims_list[1 ], ndim);
177+
178+ // Convert to WHCN format
179+ reduce_dim1 = nchw_dim_to_whcn_dim (reduce_dim1, ndim);
180+ reduce_dim2 = nchw_dim_to_whcn_dim (reduce_dim2, ndim);
181+
182+ // Check that the concat dim is not one of the reduction dims
183+ if (graph.dim_of (in) == 4 && graph.size_at <int >(0 , in) > 1 ) {
184+ VK_CHECK_COND (graph.concat_dim_of (in) != reduce_dim1);
185+ VK_CHECK_COND (graph.concat_dim_of (in) != reduce_dim2);
186+ VK_CHECK_COND (graph.concat_dim_of (out) != reduce_dim1);
187+ VK_CHECK_COND (graph.concat_dim_of (out) != reduce_dim2);
188+ }
189+
190+ std::string kernel_name = op_name + " 2d" ; // Add "2d" suffix
191+ kernel_name.reserve (kShaderNameReserve );
192+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
193+
194+ // Calculate group_dim for specialization constants (use remaining dimension)
195+ int32_t group_dim = 0 ;
196+ for (int i = 0 ; i < 3 ; i++) {
197+ if (i != reduce_dim1 && i != reduce_dim2) {
198+ group_dim = i;
199+ break ;
200+ }
201+ }
202+
203+ const ValueRef reduce_dim1_whcn_ref = graph.get_or_add_value_for_int (reduce_dim1);
204+ const ValueRef reduce_dim2_whcn_ref = graph.get_or_add_value_for_int (reduce_dim2);
205+ const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int (group_dim);
206+
207+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
208+ graph,
209+ VK_KERNEL_FROM_STR (kernel_name),
210+ reduce_global_wg_size,
211+ reduce_local_wg_size,
212+ // Inputs and Outputs
213+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
214+ // Shader params buffers
215+ {graph.logical_limits_ubo (in), graph.sizes_ubo (in)},
216+ // Push Constants
217+ {},
218+ // Specialization Constants
219+ {graph.packed_dim_of (out), reduce_dim1, reduce_dim2, group_dim},
220+ // Resize Args
221+ {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
222+ // Resizing Logic
223+ resize_reduce2d_node));
224+ }
225+
140226#define DEFINE_REDUCE_FN (op_name, out_arg_idx ) \
141227 void op_name (ComputeGraph& graph, const std::vector<ValueRef>& args) { \
142228 const std::vector<int64_t > dims_list = \
143- graph.extract_int_or_symint_list (args[1 ]); \
144- VK_CHECK_COND (dims_list.size () == 1 ); \
145- const int64_t dim_val = dims_list.at (0 ); \
146- const ValueRef dim_ref = graph.get_or_add_value_for_int (dim_val); \
147- return add_reduce_node ( \
148- graph, args[0 ], dim_ref, args[out_arg_idx], #op_name); \
229+ graph.extract_int_or_symint_list (args[1 ]); \
230+ if (dims_list.size () == 1 ) { \
231+ const int64_t dim_val = dims_list.at (0 ); \
232+ const ValueRef dim_ref = graph.get_or_add_value_for_int (dim_val); \
233+ return add_reduce_node ( \
234+ graph, args[0 ], dim_ref, args[out_arg_idx], #op_name); \
235+ } else if (dims_list.size () == 2 ) { \
236+ return add_reduce2d_node ( \
237+ graph, args[0 ], args[1 ], args[out_arg_idx], #op_name); \
238+ } else { \
239+ VK_CHECK_COND (false , " Only 1 or 2 dimensions supported" ); \
240+ } \
149241 }
150242
151243DEFINE_REDUCE_FN (sum, 4 )
0 commit comments