@@ -40,34 +40,52 @@ void add_transfer_copy_node(
4040
4141 int64_t dim_whcn = nchw_dim_to_whcn_dim (dim, ndim);
4242
43+ struct TransferParams {
44+ int32_t dim;
45+ int32_t index_or_start_ref;
46+ int32_t step_ref;
47+ } transfer_params{static_cast <int32_t >(dim_whcn), 0 , 0 };
48+
49+ const bool param_is_scalar = graph.is_scalar_or_none (index_or_start_ref) &&
50+ (transfer_type == TransferType::SELECT ||
51+ graph.is_scalar_or_none (step_ref));
52+
4353 vkapi::ParamsBindList param_buffers;
44- if (transfer_type == TransferType::SELECT) {
45- param_buffers = {
46- graph.get_or_create_int_param_buffer (index_or_start_ref, 0 )};
47- } else { // TransferType::SLICE
48- param_buffers = {
49- graph.get_or_create_int_param_buffer (index_or_start_ref, 0 ),
50- graph.get_or_create_int_param_buffer (step_ref, 1 )};
54+ if (!param_is_scalar) {
55+ if (transfer_type == TransferType::SELECT) {
56+ param_buffers = {
57+ graph.get_or_create_int_param_buffer (index_or_start_ref, 0 )};
58+ } else { // TransferType::SLICE
59+ param_buffers = {
60+ graph.get_or_create_int_param_buffer (index_or_start_ref, 0 ),
61+ graph.get_or_create_int_param_buffer (step_ref, 1 )};
62+ }
63+ } else {
64+ transfer_params.index_or_start_ref =
65+ graph.extract_scalar_or <int32_t >(index_or_start_ref, 0 );
66+ if (transfer_type != TransferType::SELECT) {
67+ transfer_params.step_ref = graph.extract_scalar_or <int32_t >(step_ref, 1 );
68+ }
5169 }
5270
53- const struct TransferParams {
54- const int32_t dim;
55- } transfer_params{static_cast <int32_t >(dim_whcn)};
56-
5771 std::vector<PushConstantDataInfo> push_constants;
72+ push_constants.reserve (graph.is_buffer_storage (out) ? 5 : 3 );
5873
5974 if (graph.is_buffer_storage (out)) {
60- push_constants = {
61- graph.sizes_pc_of (in),
62- graph.strides_pc_of (out),
63- graph.strides_pc_of (in),
64- graph.numel_pc_of (out),
65- PushConstantDataInfo (&transfer_params, sizeof (transfer_params))};
75+ push_constants.emplace_back (graph.sizes_pc_of (in));
76+ push_constants.emplace_back (graph.strides_pc_of (out));
77+ push_constants.emplace_back (graph.strides_pc_of (in));
78+ push_constants.emplace_back (graph.numel_pc_of (out));
6679 } else {
67- push_constants = {
68- graph.sizes_pc_of (out),
69- graph.sizes_pc_of (in),
70- PushConstantDataInfo (&transfer_params, sizeof (transfer_params))};
80+ push_constants.emplace_back (graph.sizes_pc_of (out));
81+ push_constants.emplace_back (graph.sizes_pc_of (in));
82+ }
83+
84+ if (param_is_scalar) {
85+ push_constants.emplace_back (&transfer_params, sizeof (transfer_params));
86+ } else {
87+ push_constants.emplace_back (
88+ &transfer_params.dim , sizeof (transfer_params.dim ));
7189 }
7290
7391 vkapi::SpecVarList spec_vars = {
@@ -82,6 +100,9 @@ void add_transfer_copy_node(
82100 } else { // TransferType::SLICE
83101 kernel_name = " slice" ;
84102 }
103+ if (!param_is_scalar) {
104+ kernel_name += " _ubo" ;
105+ }
85106 add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
86107 add_dtype_suffix (kernel_name, graph.dtype_of (out));
87108
0 commit comments