@@ -407,7 +407,9 @@ void add_conv2d_node(
407407 wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408408 }
409409
410- if (method == Conv2dMethod::Pointwise) {
410+ vkapi::ParamsBindList param_buffers;
411+ std::vector<PushConstantDataInfo> push_constants;
412+ if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
411413 const utils::ivec4 kernel_param_size_stride = {
412414 kernel_params.kernel_size [0 ],
413415 kernel_params.kernel_size [1 ],
@@ -420,55 +422,43 @@ void add_conv2d_node(
420422 kernel_params.dilation [0 ],
421423 kernel_params.dilation [1 ]};
422424
423- graph.execute_nodes ().emplace_back (new DispatchNode (
424- graph,
425- shader,
426- wg_size,
427- graph.create_local_wg_size (wg_size),
428- // Inputs and Outputs
429- {{out, vkapi::MemoryAccessType::WRITE},
430- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431- // Shader params buffers
432- {},
433- // Specialization Constants
434- {},
435- // Resizing Logic
436- resize_conv2d_node,
437- {weight_data, stride, padding, dilation, transposed, output_padding},
438- {
439- graph.logical_limits_pc_of (out),
440- graph.sizes_pc_of (in),
441- PushConstantDataInfo (
442- &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443- PushConstantDataInfo (
444- &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445- PushConstantDataInfo (
446- &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447- PushConstantDataInfo (&out_params, sizeof (out_params)),
448- }));
425+ push_constants = {
426+ graph.logical_limits_pc_of (out),
427+ graph.sizes_pc_of (in),
428+ PushConstantDataInfo (
429+ &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
430+ PushConstantDataInfo (
431+ &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
432+ PushConstantDataInfo (
433+ &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
434+ PushConstantDataInfo (&out_params, sizeof (out_params)),
435+ };
449436 } else {
450- graph.execute_nodes ().emplace_back (new DispatchNode (
451- graph,
452- shader,
453- wg_size,
454- graph.create_local_wg_size (wg_size),
455- // Inputs and Outputs
456- {{out, vkapi::MemoryAccessType::WRITE},
457- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458- // Shader params buffers
459- {
460- t_out->logical_limits_ubo (),
461- t_in->sizes_ubo (),
462- graph.create_params_buffer (kernel_params),
463- graph.create_params_buffer (extra_params),
464- graph.create_params_buffer (out_params),
465- },
466- // Specialization Constants
467- {},
468- // Resizing Logic
469- resize_conv2d_node,
470- {weight_data, stride, padding, dilation, transposed, output_padding}));
437+ param_buffers = {
438+ t_out->logical_limits_ubo (),
439+ t_in->sizes_ubo (),
440+ graph.create_params_buffer (kernel_params),
441+ graph.create_params_buffer (extra_params),
442+ graph.create_params_buffer (out_params),
443+ };
471444 }
445+
446+ graph.execute_nodes ().emplace_back (new DispatchNode (
447+ graph,
448+ shader,
449+ wg_size,
450+ graph.create_local_wg_size (wg_size),
451+ // Inputs and Outputs
452+ {{out, vkapi::MemoryAccessType::WRITE},
453+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
454+ // Shader params buffers
455+ std::move (param_buffers),
456+ // Specialization Constants
457+ {},
458+ // Resizing Logic
459+ resize_conv2d_node,
460+ {weight_data, stride, padding, dilation, transposed, output_padding},
461+ std::move (push_constants)));
472462}
473463
474464void add_conv1d_node (
0 commit comments