@@ -407,27 +407,68 @@ void add_conv2d_node(
407407 wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408408 }
409409
410- graph.execute_nodes ().emplace_back (new DispatchNode (
411- graph,
412- shader,
413- wg_size,
414- graph.create_local_wg_size (wg_size),
415- // Inputs and Outputs
416- {{out, vkapi::MemoryAccessType::WRITE},
417- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
418- // Shader params buffers
419- {
420- t_out->logical_limits_ubo (),
421- t_in->sizes_ubo (),
422- graph.create_params_buffer (kernel_params),
423- graph.create_params_buffer (extra_params),
424- graph.create_params_buffer (out_params),
425- },
426- // Specialization Constants
427- {},
428- // Resizing Logic
429- resize_conv2d_node,
430- {weight_data, stride, padding, dilation, transposed, output_padding}));
410+ if (method == Conv2dMethod::Pointwise) {
411+ const utils::ivec4 kernel_param_size_stride = {
412+ kernel_params.kernel_size [0 ],
413+ kernel_params.kernel_size [1 ],
414+ kernel_params.stride [0 ],
415+ kernel_params.stride [1 ]};
416+
417+ const utils::ivec4 kernel_param_pad_dial = {
418+ kernel_params.padding [0 ],
419+ kernel_params.padding [1 ],
420+ kernel_params.dilation [0 ],
421+ kernel_params.dilation [1 ]};
422+
423+ graph.execute_nodes ().emplace_back (new DispatchNode (
424+ graph,
425+ shader,
426+ wg_size,
427+ graph.create_local_wg_size (wg_size),
428+ // Inputs and Outputs
429+ {{out, vkapi::MemoryAccessType::WRITE},
430+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431+ // Shader params buffers
432+ {},
433+ // Specialization Constants
434+ {},
435+ // Resizing Logic
436+ resize_conv2d_node,
437+ {weight_data, stride, padding, dilation, transposed, output_padding},
438+ {
439+ graph.logical_limits_pc_of (out),
440+ graph.sizes_pc_of (in),
441+ PushConstantDataInfo (
442+ &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443+ PushConstantDataInfo (
444+ &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445+ PushConstantDataInfo (
446+ &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447+ PushConstantDataInfo (&out_params, sizeof (out_params)),
448+ }));
449+ } else {
450+ graph.execute_nodes ().emplace_back (new DispatchNode (
451+ graph,
452+ shader,
453+ wg_size,
454+ graph.create_local_wg_size (wg_size),
455+ // Inputs and Outputs
456+ {{out, vkapi::MemoryAccessType::WRITE},
457+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458+ // Shader params buffers
459+ {
460+ t_out->logical_limits_ubo (),
461+ t_in->sizes_ubo (),
462+ graph.create_params_buffer (kernel_params),
463+ graph.create_params_buffer (extra_params),
464+ graph.create_params_buffer (out_params),
465+ },
466+ // Specialization Constants
467+ {},
468+ // Resizing Logic
469+ resize_conv2d_node,
470+ {weight_data, stride, padding, dilation, transposed, output_padding}));
471+ }
431472}
432473
433474void add_conv1d_node (
0 commit comments