@@ -407,27 +407,68 @@ void add_conv2d_node(
407407    wg_size = {wg_size[0 ] * wg_size[1 ] * wg_size[2 ], 1 , 1 };
408408  }
409409
410-   graph.execute_nodes ().emplace_back (new  DispatchNode (
411-       graph,
412-       shader,
413-       wg_size,
414-       graph.create_local_wg_size (wg_size),
415-       //  Inputs and Outputs
416-       {{out, vkapi::MemoryAccessType::WRITE},
417-        {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
418-       //  Shader params buffers
419-       {
420-           t_out->logical_limits_ubo (),
421-           t_in->sizes_ubo (),
422-           graph.create_params_buffer (kernel_params),
423-           graph.create_params_buffer (extra_params),
424-           graph.create_params_buffer (out_params),
425-       },
426-       //  Specialization Constants
427-       {},
428-       //  Resizing Logic
429-       resize_conv2d_node,
430-       {weight_data, stride, padding, dilation, transposed, output_padding}));
410+   if  (method == Conv2dMethod::Pointwise) {
411+     const  utils::ivec4 kernel_param_size_stride = {
412+         kernel_params.kernel_size [0 ],
413+         kernel_params.kernel_size [1 ],
414+         kernel_params.stride [0 ],
415+         kernel_params.stride [1 ]};
416+ 
417+     const  utils::ivec4 kernel_param_pad_dial = {
418+         kernel_params.padding [0 ],
419+         kernel_params.padding [1 ],
420+         kernel_params.dilation [0 ],
421+         kernel_params.dilation [1 ]};
422+ 
423+     graph.execute_nodes ().emplace_back (new  DispatchNode (
424+         graph,
425+         shader,
426+         wg_size,
427+         graph.create_local_wg_size (wg_size),
428+         //  Inputs and Outputs
429+         {{out, vkapi::MemoryAccessType::WRITE},
430+          {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
431+         //  Shader params buffers
432+         {},
433+         //  Specialization Constants
434+         {},
435+         //  Resizing Logic
436+         resize_conv2d_node,
437+         {weight_data, stride, padding, dilation, transposed, output_padding},
438+         {
439+             graph.logical_limits_pc_of (out),
440+             graph.sizes_pc_of (in),
441+             PushConstantDataInfo (
442+                 &kernel_param_size_stride, sizeof (kernel_param_size_stride)),
443+             PushConstantDataInfo (
444+                 &kernel_param_pad_dial, sizeof (kernel_param_pad_dial)),
445+             PushConstantDataInfo (
446+                 &extra_params, sizeof (extra_params), sizeof (utils::ivec4)),
447+             PushConstantDataInfo (&out_params, sizeof (out_params)),
448+         }));
449+   } else  {
450+     graph.execute_nodes ().emplace_back (new  DispatchNode (
451+         graph,
452+         shader,
453+         wg_size,
454+         graph.create_local_wg_size (wg_size),
455+         //  Inputs and Outputs
456+         {{out, vkapi::MemoryAccessType::WRITE},
457+          {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
458+         //  Shader params buffers
459+         {
460+             t_out->logical_limits_ubo (),
461+             t_in->sizes_ubo (),
462+             graph.create_params_buffer (kernel_params),
463+             graph.create_params_buffer (extra_params),
464+             graph.create_params_buffer (out_params),
465+         },
466+         //  Specialization Constants
467+         {},
468+         //  Resizing Logic
469+         resize_conv2d_node,
470+         {weight_data, stride, padding, dilation, transposed, output_padding}));
471+   }
431472}
432473
433474void  add_conv1d_node (
0 commit comments