@@ -407,7 +407,9 @@ void add_conv2d_node(
407407    wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408408  }
409409
410-   if  (method == Conv2dMethod::Pointwise) {
410+   vkapi::ParamsBindList param_buffers;
411+   std::vector<PushConstantDataInfo> push_constants;
412+   if  (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
411413    const  utils::ivec4 kernel_param_size_stride = {
412414        kernel_params.kernel_size [0 ],
413415        kernel_params.kernel_size [1 ],
@@ -420,55 +422,43 @@ void add_conv2d_node(
420422        kernel_params.dilation [0 ],
421423        kernel_params.dilation [1 ]};
422424
423-     graph.execute_nodes ().emplace_back (new  DispatchNode (
424-         graph,
425-         shader,
426-         wg_size,
427-         graph.create_local_wg_size (wg_size),
428-         //  Inputs and Outputs
429-         {{out, vkapi::MemoryAccessType::WRITE},
430-          {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431-         //  Shader params buffers
432-         {},
433-         //  Specialization Constants
434-         {},
435-         //  Resizing Logic
436-         resize_conv2d_node,
437-         {weight_data, stride, padding, dilation, transposed, output_padding},
438-         {
439-             graph.logical_limits_pc_of (out),
440-             graph.sizes_pc_of (in),
441-             PushConstantDataInfo (
442-                 &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443-             PushConstantDataInfo (
444-                 &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445-             PushConstantDataInfo (
446-                 &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447-             PushConstantDataInfo (&out_params, sizeof (out_params)),
448-         }));
425+     push_constants = {
426+         graph.logical_limits_pc_of (out),
427+         graph.sizes_pc_of (in),
428+         PushConstantDataInfo (
429+             &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
430+         PushConstantDataInfo (
431+             &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
432+         PushConstantDataInfo (
433+             &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
434+         PushConstantDataInfo (&out_params, sizeof (out_params)),
435+     };
449436  } else  {
450-     graph.execute_nodes ().emplace_back (new  DispatchNode (
451-         graph,
452-         shader,
453-         wg_size,
454-         graph.create_local_wg_size (wg_size),
455-         //  Inputs and Outputs
456-         {{out, vkapi::MemoryAccessType::WRITE},
457-          {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458-         //  Shader params buffers
459-         {
460-             t_out->logical_limits_ubo (),
461-             t_in->sizes_ubo (),
462-             graph.create_params_buffer (kernel_params),
463-             graph.create_params_buffer (extra_params),
464-             graph.create_params_buffer (out_params),
465-         },
466-         //  Specialization Constants
467-         {},
468-         //  Resizing Logic
469-         resize_conv2d_node,
470-         {weight_data, stride, padding, dilation, transposed, output_padding}));
437+     param_buffers = {
438+         t_out->logical_limits_ubo (),
439+         t_in->sizes_ubo (),
440+         graph.create_params_buffer (kernel_params),
441+         graph.create_params_buffer (extra_params),
442+         graph.create_params_buffer (out_params),
443+     };
471444  }
445+ 
446+   graph.execute_nodes ().emplace_back (new  DispatchNode (
447+       graph,
448+       shader,
449+       wg_size,
450+       graph.create_local_wg_size (wg_size),
451+       //  Inputs and Outputs
452+       {{out, vkapi::MemoryAccessType::WRITE},
453+        {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
454+       //  Shader params buffers
455+       param_buffers,
456+       //  Specialization Constants
457+       {},
458+       //  Resizing Logic
459+       resize_conv2d_node,
460+       {weight_data, stride, padding, dilation, transposed, output_padding},
461+       push_constants));
472462}
473463
474464void  add_conv1d_node (
0 commit comments