diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 9bf317b9067..52194ea82e3 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -285,7 +285,8 @@ ValueRef ComputeGraph::add_tensor_like(
 ValueRef ComputeGraph::add_tensor_like(
     const ValueRef idx,
     const utils::GPUMemoryLayout memory_layout) {
-  return add_tensor(sizes_of(idx), dtype_of(idx), memory_layout);
+  return add_tensor(
+      sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout);
 }
 
 ValueRef ComputeGraph::add_tensor(
diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
index 43337a6cb0b..ec7b6c2fc12 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
@@ -18,9 +18,10 @@
 
 namespace vkcompute {
 
-ValueRef prepack_arg(
+ValueRef check_and_prepack_arg(
     ComputeGraph& graph,
     ValueRef arg_ref,
+    const utils::StorageType stype,
     int64_t num_channels,
     const std::string& debug_name) {
   VK_CHECK_COND(
@@ -33,7 +34,7 @@ ValueRef prepack_arg(
   // batch_norm's param are broadcasted on the channel dimension.
   // In this implementation, we pack the weights along the x dimension, and
   // in the shader, we lookup using the along the x.
-  return prepack_if_tensor_ref(graph, arg_ref, utils::kWidthPacked);
+  return prepack_standard(graph, arg_ref, stype, utils::kWidthPacked);
 }
 
 void add_native_batch_norm_node(
@@ -51,22 +52,26 @@ void add_native_batch_norm_node(
   VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor");
   VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor");
 
+  // Only the first element of the return value is propagated. The remaining 2
+  // elements are zero-size dummy tensors.
+  ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
+
+  utils::StorageType stype = graph.storage_type_of(out_ref);
+
   int64_t num_channels = dim_at(in_sizes);
 
-  ValueRef arg_weight = prepack_arg(graph, weight_ref, num_channels, "weight");
-  ValueRef arg_bias = prepack_arg(graph, bias_ref, num_channels, "bias");
-  ValueRef arg_mean = prepack_arg(graph, mean_ref, num_channels, "mean");
-  ValueRef arg_var = prepack_arg(graph, var_ref, num_channels, "var");
+  ValueRef arg_weight =
+      check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight");
+  ValueRef arg_bias =
+      check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias");
+  ValueRef arg_mean =
+      check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean");
+  ValueRef arg_var =
+      check_and_prepack_arg(graph, var_ref, stype, num_channels, "var");
 
   float epsilon = graph.extract_scalar<float>(eps_ref);
 
   vTensorPtr t_in = graph.get_tensor(in_ref);
 
-  // Only the first element of the return value is propagated. The remaining 2
-  // elements are zero-size dummy tensor.
-  const auto out_tuple_val = graph.get_value_list(out_tuple_ref);
-
-  ValueRef out_ref = out_tuple_val->at(0);
-
   VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref");
   vTensorPtr t_out = graph.get_tensor(out_ref);
diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
index c701597984b..c055431a84b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -51,9 +51,8 @@ void add_binary_op_node(
     const ValueRef alpha,
     const ValueRef out,
     const std::string& op_name) {
-  ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
-  ValueRef arg2 =
-      prepack_if_tensor_ref(graph, in2, graph.estimate_memory_layout_of(arg1));
+  ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
+  ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
 
   vTensorPtr t_in1 = graph.get_tensor(arg1);
   vTensorPtr t_in2 = graph.get_tensor(arg2);
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 4afaef04d8c..43568622f84 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -304,7 +304,7 @@ utils::uvec3 create_conv2d_global_wg_size(
 void add_conv2d_node(
     ComputeGraph& graph,
     const ValueRef in,
-    const ValueRef weight,
+    const ValueRef weight_data,
     const ValueRef bias,
     const ValueRef stride,
     const ValueRef padding,
@@ -330,19 +330,18 @@ void add_conv2d_node(
   const int64_t groups_val = graph.get_int(groups);
 
   const Conv2dMethod method =
-      get_conv2d_method(graph, weight, groups_val, transposed_val);
+      get_conv2d_method(graph, weight_data, groups_val, transposed_val);
 
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight = prepack_weights(graph, weight, method);
+  ValueRef arg_weight = prepack_weights(graph, weight_data, method);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
-      weight,
+      weight_data,
       transposed_val,
       /* storage_type = */ utils::kTexture2D,
       /* memory_layout = */ utils::kWidthPacked);
 
-  vTensorPtr t_in = graph.get_tensor(arg_in);
+  vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
   if (t_in->sizes().at(0) > 1) {
     VK_THROW("conv2d: input batch size > 1 is not supported yet!");
@@ -351,20 +350,25 @@ void add_conv2d_node(
   Kernel2dParams kernel_params = create_kernel2d_params(
       graph,
-      weight,
+      weight_data,
       /*kernel_size_only = */ false,
       stride,
       padding,
       dilation);
 
   Conv2dParams extra_params =
-      create_conv2d_params(graph, weight, kernel_params, transposed_val);
+      create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
 
   OutputParams out_params = {out_min_val, out_max_val};
 
   check_conv2d_params(kernel_params, transposed_val);
 
   vkapi::ShaderInfo shader = get_conv2d_shader(
-      graph, *t_out, /*prepack_weights = */ false, method, weight, clamp_out);
+      graph,
+      *t_out,
+      /*prepack_weights = */ false,
+      method,
+      weight_data,
+      clamp_out);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -373,7 +377,7 @@ void add_conv2d_node(
       graph.create_local_wg_size(out),
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {
           t_out->logical_limits_ubo(),
@@ -386,7 +390,7 @@ void add_conv2d_node(
       {},
       // Resizing Logic
       resize_conv2d_node,
-      {weight, stride, padding, dilation, transposed, output_padding}));
+      {weight_data, stride, padding, dilation, transposed, output_padding}));
 }
 
 void add_conv1d_node(
@@ -402,9 +406,8 @@
     const ValueRef out_max,
     const ValueRef out,
     const bool clamp_out) {
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight =
-      prepack_if_tensor_ref(graph, weight, utils::kWidthPacked);
+  ValueRef arg_weight = prepack_standard(
+      graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
@@ -422,7 +425,7 @@ void add_conv1d_node(
     out_max_val = graph.extract_scalar<float>(out_max);
   }
 
-  vTensorPtr t_in = graph.get_tensor(arg_in);
+  vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_weight = graph.get_tensor(arg_weight);
   vTensorPtr t_bias = graph.get_tensor(arg_bias);
   vTensorPtr t_out = graph.get_tensor(out);
@@ -471,7 +474,7 @@ void add_conv1d_node(
       local_size,
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {
           t_out->logical_limits_ubo(),
diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
index 0658d21108b..beaeed59baa 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
@@ -57,9 +57,9 @@ void add_embedding_node(
 }
 
 void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  ValueRef weight = prepack_if_tensor_ref(graph, args[0]);
-  ValueRef in = prepack_if_tensor_ref(graph, args[1]);
+  ValueRef in = args[1];
   ValueRef out = args[5];
+  ValueRef weight = prepack_standard_like(graph, args[0], out);
 
   add_embedding_node(graph, weight, in, out);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp
index 0ff217b4f89..1f56d3c45d3 100644
--- a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp
@@ -108,9 +108,9 @@ int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
 }
 
 void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  ValueRef in = prepack_if_tensor_ref(graph, args[0]);
+  ValueRef in = args[0];
   ValueRef dim_ref = args[1];
-  ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
+  ValueRef idx = args[2];
   ValueRef out = args[3];
 
   const int64_t dim_idx = get_dim_idx(graph, in, dim_ref);
diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 44a4acacc55..74afce1abe3 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -94,8 +94,11 @@ void add_addmm_naive_node(
     const ValueRef out,
     const Params& params,
     const ValueRef mat2_is_transposed) {
-  ValueRef self = prepack_if_tensor_ref(graph, self_data, utils::kWidthPacked);
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+  utils::StorageType stype = graph.storage_type_of(out);
+  ValueRef self = prepack_standard(
+      graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true);
 
"linear_naive" : "addmm_naive"; @@ -145,9 +148,11 @@ void add_addmm_optimized_node( const ValueRef out, const Params& params, const ValueRef mat2_is_transposed) { - ValueRef self = - prepack_if_tensor_ref(graph, self_data, utils::kChannelsPacked); - ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked); + utils::StorageType stype = graph.storage_type_of(out); + ValueRef self = prepack_standard( + graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true); + ValueRef mat2 = prepack_standard( + graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true); // Ensure mat1 is width packed ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); @@ -276,8 +281,8 @@ void linear(ComputeGraph& graph, const std::vector& args) { ValueRef weight_data = args.at(1); ValueRef bias = args.at(2); ValueRef out = args.at(3); - ValueRef weight = - prepack_if_tensor_ref(graph, weight_data, utils::kWidthPacked); + ValueRef weight = prepack_standard( + graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked); ValueRef mat2_is_transposed = graph.add_scalar(true); if (graph.val_is_none(bias)) { return add_matmul_node(graph, input, weight, out, mat2_is_transposed); diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index ca55d6eeb0c..71e9033cec2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -62,7 +62,12 @@ void add_matmul_naive_buffer_node( const ValueRef mat2_data, const ValueRef out, const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked); + ValueRef mat2 = prepack_standard( + graph, + mat2_data, + graph.storage_type_of(out), + utils::kHeightPacked, + /*passthrough = */ true); std::string kernel_name = "matmul_naive_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); @@ -103,7 +108,12 @@ void add_matmul_naive_texture3d_node( const ValueRef mat2_data, const ValueRef out, const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked); + ValueRef mat2 = prepack_standard( + graph, + mat2_data, + graph.storage_type_of(out), + utils::kHeightPacked, + /*passthrough = */ true); std::string kernel_name = graph.get_bool(mat2_is_transposed) ? 
"matmul_transposed_naive" @@ -146,7 +156,12 @@ void add_matmul_optimized_node( const ValueRef mat2_data, const ValueRef out, const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked); + ValueRef mat2 = prepack_standard( + graph, + mat2_data, + graph.storage_type_of(out), + utils::kHeightPacked, + /*passthrough = */ true); // Ensure mat1 is width packed ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 8e2f296f3c9..0e30d8a2c6e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -57,8 +57,8 @@ void add_native_layer_norm_node( ComputeGraph& graph, const ValueRef in, const ValueRef normalized_shape, - const ValueRef weight, - const ValueRef bias, + const ValueRef weight_data, + const ValueRef bias_data, const ValueRef eps, const ValueRef out) { const auto normalized_shape_dim = @@ -67,19 +67,16 @@ void add_native_layer_norm_node( VK_THROW("native_layer_norm only supports normalized_shape with dim == 1"); } - if (graph.val_is_none(weight)) { + if (graph.val_is_none(weight_data)) { VK_THROW("native_layer_norm requires weight to be non-None"); } - if (graph.val_is_none(bias)) { + if (graph.val_is_none(bias_data)) { VK_THROW("native_layer_norm requires bias to be non-None"); } - ValueRef arg_in = prepack_if_tensor_ref(graph, in); - ValueRef arg_weight = prepack_if_tensor_ref( - graph, weight, graph.estimate_memory_layout_of(arg_in)); - ValueRef arg_bias = prepack_if_tensor_ref( - graph, bias, graph.estimate_memory_layout_of(arg_in)); + ValueRef arg_weight = prepack_standard_like(graph, weight_data, in); + ValueRef arg_bias = prepack_standard_like(graph, bias_data, in); const auto out_val = graph.get_value_list(out); vTensorPtr t_out = graph.get_tensor(out_val->at(0)); @@ -107,7 +104,7 @@ void add_native_layer_norm_node( // Inputs and Outputs {{{out_val->at(0), out_val->at(1), out_val->at(2)}, vkapi::MemoryAccessType::WRITE}, - {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}}, + {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}}, // Shader params buffers {t_out->logical_limits_ubo(), t_out->sizes_ubo(), diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index bba298202f8..b7015d2b1a0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -71,8 +71,7 @@ void add_max_pool2d_node( const ValueRef dilation, const ValueRef ceil_mode, const ValueRef out) { - ValueRef arg = prepack_if_tensor_ref(graph, in); - vTensorPtr t_in = graph.get_tensor(arg); + vTensorPtr t_in = graph.get_tensor(in); const auto out_val = graph.get_value_list(out); vTensorPtr t_out = graph.get_tensor(out_val->at(0)); @@ -100,7 +99,7 @@ void add_max_pool2d_node( local_size, // Inputs and Outputs {{{out_val->at(0), out_val->at(1)}, vkapi::MemoryAccessType::WRITE}, - {arg, vkapi::MemoryAccessType::READ}}, + {in, vkapi::MemoryAccessType::READ}}, // Shader params buffers { t_out->logical_limits_ubo(), @@ -149,8 +148,7 @@ void add_avg_pool2d_node( const ValueRef count_include_pad, const ValueRef divisor_override, const ValueRef out) { - ValueRef arg = prepack_if_tensor_ref(graph, in); - vTensorPtr t_in = graph.get_tensor(arg); + vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = 
 
   check_pool2d_args(*t_in, *t_out);
@@ -174,7 +172,7 @@
       local_size,
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {arg, vkapi::MemoryAccessType::READ}},
+       {in, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {t_out->logical_limits_ubo(),
        t_in->sizes_ubo(),
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 5642976b7fe..cb3bafbb81b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -81,10 +81,10 @@ void add_q_8w_linear_node(
     // Ensure out is packed correctly
     out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
   }
-  ValueRef q_mat2 =
-      prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
-  ValueRef scales =
-      prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
+  ValueRef q_mat2 = prepack_standard(
+      graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
+  ValueRef scales = prepack_standard(
+      graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear";
   kernel_name.reserve(kShaderNameReserve);
@@ -146,10 +146,12 @@ void add_q_8w_linear_optimized_node(
     // Ensure out is packed correctly
     out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
   }
+
+  utils::StorageType stype = graph.storage_type_of(out);
   ValueRef q_mat2 =
-      prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
+      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
   ValueRef scales =
-      prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
+      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear_optimized";
   kernel_name.reserve(kShaderNameReserve);
@@ -295,11 +297,13 @@ void add_q_4w_linear_node(
 
   utils::StorageType storage_type = graph.storage_type_of(out);
 
-  ValueRef mat2 =
-      prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
+  ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data);
 
-  ValueRef scales_and_zeros =
-      prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked);
+  ValueRef scales_and_zeros = prepack_standard(
+      graph,
+      scales_and_zeros_data,
+      graph.storage_type_of(out),
+      utils::kWidthPacked);
 
   std::string kernel_name = "q_4w_linear";
   add_storage_type_suffix(kernel_name, storage_type);
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 28dd9d1e68b..2c462013513 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -223,9 +223,9 @@ void sdpa_with_kv_cache_impl(
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
   const ValueRef k_cache =
-      prepack_if_tensor_ref(graph, k_cache_data, utils::kWidthPacked);
+      prepack_standard_like(graph, k_cache_data, q_projected);
   const ValueRef v_cache =
-      prepack_if_tensor_ref(graph, v_cache_data, utils::kWidthPacked);
+      prepack_standard_like(graph, v_cache_data, q_projected);
 
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
index 4a709fce994..ac7b223eff8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -110,107 +110,99 @@ void add_tensor_to_staging_node(
       {SV(graph.packed_dim_of(in_tensor))}));
 }
 
-ValueRef prepack(
+void add_prepack_standard_node(
     ComputeGraph& graph,
-    const ValueRef vref,
-    const utils::GPUMemoryLayout layout) {
-  ValueRef v = graph.add_tensor_like(vref, layout);
-
+    const ValueRef tensor_data,
+    const ValueRef tensor) {
   vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
-      *graph.get_tensor(v), graph.int8_buffers_enabled());
+      *graph.get_tensor(tensor), graph.int8_buffers_enabled());
 
   vkapi::ParamsBindList ubos;
-  if (graph.is_buffer_storage(v)) {
-    ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)});
+  if (graph.is_buffer_storage(tensor)) {
+    ubos.append(
+        {graph.sizes_ubo(tensor),
+         graph.strides_ubo(tensor),
+         graph.numel_ubo(tensor)});
   } else {
-    ubos.append({graph.sizes_ubo(v), graph.axis_map_ubo(v)});
+    ubos.append({graph.sizes_ubo(tensor), graph.axis_map_ubo(tensor)});
   }
 
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
       shader,
-      graph.create_global_wg_size(v),
-      graph.create_local_wg_size(v),
+      graph.create_global_wg_size(tensor),
+      graph.create_local_wg_size(tensor),
       // Input and Outputs
-      vref,
-      v,
+      tensor_data,
+      tensor,
       // Parameter Buffers
       ubos,
       // Specialization Constants
-      {SV(graph.packed_dim_of(v))}));
+      {SV(graph.packed_dim_of(tensor))}));
+}
 
-  return v;
+ValueRef prepack_standard(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough) {
+  if (passthrough && graph.val_is_tensor(tensor_data)) {
+    return tensor_data;
+  }
+  VK_CHECK_COND(graph.val_is_tref(tensor_data));
+  ValueRef tensor = graph.add_tensor_like(tensor_data, storage_type, layout);
+  add_prepack_standard_node(graph, tensor_data, tensor);
+  return tensor;
 }
 
-ValueRef prepack_buffer(
+ValueRef prepack_standard_like(
     ComputeGraph& graph,
-    const ValueRef vref,
-    const utils::GPUMemoryLayout layout) {
-  ValueRef v = graph.add_tensor_like(vref, utils::kBuffer, layout);
+    const ValueRef tensor_data,
+    const ValueRef to_copy,
+    const bool passthrough) {
+  VK_CHECK_COND(graph.val_is_tensor(to_copy));
+  return prepack_standard(
+      graph,
+      tensor_data,
+      graph.storage_type_of(to_copy),
+      graph.estimate_memory_layout_of(to_copy),
+      passthrough);
+}
 
+void add_prepack_direct_copy_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const ValueRef tensor) {
   std::string kernel_name = "buffer_to_buffer";
-  add_dtype_suffix(kernel_name, graph.dtype_of(vref));
+  add_dtype_suffix(kernel_name, graph.dtype_of(tensor_data));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 
   vkapi::ParamsBindList ubos;
-  ubos.append({graph.numel_ubo(v)});
+  ubos.append({graph.numel_ubo(tensor)});
 
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
       shader,
-      graph.create_global_wg_size(v),
-      graph.create_local_wg_size(v),
+      graph.create_global_wg_size(tensor),
+      graph.create_local_wg_size(tensor),
       // Input and Outputs
-      vref,
-      v,
+      tensor_data,
+      tensor,
       // Parameter Buffers
       ubos,
       // Specialization Constants
       {}));
-
-  return v;
-}
-
-ValueRef prepack_if_tensor_ref(
-    ComputeGraph& graph,
-    const ValueRef v,
-    const utils::GPUMemoryLayout layout) {
-  if (graph.val_is_tref(v)) {
-    return prepack(graph, v, layout);
-  } else {
-    return v;
-  }
 }
 
-ValueRef prepack_buffer_if_tensor_ref(
+ValueRef prepack_direct_copy_buffer(
     ComputeGraph& graph,
-    const ValueRef v,
-    const utils::GPUMemoryLayout layout) {
-  if (graph.val_is_tref(v)) {
-    return prepack_buffer(graph, v, layout);
-  } else {
-    return v;
-  }
-}
-
-ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
-  if (graph.val_is_tref(v)) {
-    utils::GPUMemoryLayout layout =
-        graph.suggested_memory_layout(graph.get_tref(v)->sizes);
-    return prepack(graph, v, layout);
-  } else {
-    return v;
-  }
-}
-
-ValueRef prepack_buffer_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
-  if (graph.val_is_tref(v)) {
-    utils::GPUMemoryLayout layout =
-        graph.suggested_memory_layout(graph.get_tref(v)->sizes);
-    return prepack_buffer(graph, v, layout);
-  } else {
-    return v;
-  }
+    const ValueRef tensor_data) {
+  VK_CHECK_COND(graph.val_is_tref(tensor_data));
+  ValueRef tensor =
+      graph.add_tensor_like(tensor_data, utils::kBuffer, utils::kWidthPacked);
+  add_prepack_direct_copy_buffer_node(graph, tensor_data, tensor);
+  return tensor;
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h
index 88a9630239a..add9162d85f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h
@@ -14,6 +14,10 @@
 
 namespace vkcompute {
 
+//
+// Staging Buffer <-> Tensor
+//
+
 void add_staging_to_tensor_node(
     ComputeGraph& graph,
     const ValueRef in_staging,
@@ -24,18 +28,50 @@ void add_tensor_to_staging_node(
     const ValueRef in_tensor,
     const ValueRef out_staging);
 
-ValueRef prepack_if_tensor_ref(
+//
+// Standard Prepack
+//
+
+/*
+ * Given that `tensor_data` is a `TensorRef`, create a new `Tensor` value with
+ * the specified `storage_type` and `memory_layout`, and add a prepacking node
+ * to transfer the `TensorRef` data to the new `Tensor` object via a
+ * staging-to-tensor shader. The created `Tensor` value is then returned.
+ *
+ * If `passthrough` is `true`, then `tensor_data` may be a `Tensor` as well. If
+ * `tensor_data` is a `Tensor`, then it is returned as-is. If `passthrough` is
+ * `false` (default), then an exception will be thrown.
+ */
+
+ValueRef prepack_standard(
     ComputeGraph& graph,
-    const ValueRef v,
-    const utils::GPUMemoryLayout layout);
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough = false);
 
-ValueRef prepack_buffer_if_tensor_ref(
+/*
+ * Equivalent to the `prepack_standard()` function, except the `storage_type`
+ * and `memory_layout` are set to match `to_copy`, which must be a `Tensor`.
+ */
+ValueRef prepack_standard_like(
     ComputeGraph& graph,
-    const ValueRef v,
-    const utils::GPUMemoryLayout layout);
+    const ValueRef tensor_data,
+    const ValueRef to_copy,
+    const bool passthrough = false);
 
-ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v);
+//
+// Direct buffer copy prepack
+//
 
-ValueRef prepack_buffer_if_tensor_ref(ComputeGraph& graph, const ValueRef v);
+/*
+ * Given that `tensor_data` is a `TensorRef`, create a new `Tensor` value with
+ * buffer storage and `kWidthPacked` memory layout, and add a prepacking node
+ * to transfer the `TensorRef` data to the new `Tensor` object via a direct
+ * buffer-to-buffer copy shader.
+ */
+ValueRef prepack_direct_copy_buffer(
+    ComputeGraph& graph,
+    const ValueRef tensor_data);
 
 } // namespace vkcompute
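The sketch below is not part of the patch; it illustrates how an op implementation is expected to call the new prepack API declared above, following the pattern of the call sites updated in this diff (BinaryOp, Linear, Embedding). The names `weight_data`, `input`, and `out` are hypothetical stand-ins for values set up by the caller.

    // Hypothetical call site (not from the patch), mirroring the updated ops.
    // `weight_data` refers to constant tensor data (a TensorRef); `out` is the
    // op's output tensor. The packed weight tensor adopts the output's storage
    // type, with an explicitly chosen memory layout.
    ValueRef weight = prepack_standard(
        graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);

    // Variant that copies both the storage type and the estimated memory
    // layout from an existing tensor. With passthrough = true, a value that
    // is already a Tensor is returned as-is instead of failing the tref check.
    ValueRef arg =
        prepack_standard_like(graph, input, out, /*passthrough = */ true);

Note that, unlike the removed prepack_if_tensor_ref, execution-time inputs are no longer routed through these helpers; the updated ops read them directly.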
diff --git a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp
index 7c0440ac052..7dd3762ecf1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp
@@ -57,8 +57,6 @@ void add_sum_dim_node(
     const int64_t dim,
     const bool keepdim,
     const ValueRef out) {
-  ValueRef arg = prepack_if_tensor_ref(graph, in);
-
   vTensorPtr t_out = graph.get_tensor(out);
   vTensorPtr t_input = graph.get_tensor(in);
 
@@ -83,7 +81,7 @@ void add_sum_dim_node(
       graph.create_local_wg_size(out),
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {arg, vkapi::MemoryAccessType::READ}},
+       {in, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {t_out->logical_limits_ubo(),
        graph.create_params_buffer(dim + 4 - in_dim),
diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
index 225c3918c4d..d1145a925d4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
@@ -37,8 +37,7 @@ void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
       " <-> ",
       vkapi::to_string(graph.dtype_of(out)));
 
-  graph.execute_nodes().emplace_back(
-      new BlitNode(graph, prepack_if_tensor_ref(graph, in), out));
+  graph.execute_nodes().emplace_back(new BlitNode(graph, in, out));
 }
 
 void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
diff --git a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp
index c9463be2da1..73f8055c284 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp
@@ -63,8 +63,6 @@ void add_upsample_nearest2d_node(
         "Invalid input, must provide ONLY one of output_sizes or scale_factors");
   }
 
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-
   vTensorPtr t_in = graph.get_tensor(in);
   utils::uvec3 input_sizes = t_in->logical_limits();
 
@@ -103,7 +101,7 @@
       graph.create_local_wg_size(out),
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {arg_in, vkapi::MemoryAccessType::READ}},
+       {in, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {t_out->logical_limits_ubo(),
        graph.create_params_buffer(input_size),
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 2d94dd39957..889d3282aae 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -155,7 +155,8 @@ def get_weight_int8pack_mm_inputs():
     test_suite.dtypes = ["at::kFloat", "at::kHalf"]
     test_suite.layouts = ["utils::kWidthPacked"]
     test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
-    test_suite.prepacked_args = ["mat2"]
+    test_suite.prepacked_args = ["mat2", "scales"]
+    test_suite.requires_prepack = True
 
     test_suite.arg_dtype["mat2"] = "at::kChar"
     test_suite.arg_data_range["mat2"] = (0, 100)
@@ -1084,6 +1085,8 @@ def get_native_batch_norm_inputs():
     ]
 
     test_suite = VkTestSuite(test_cases)
+    test_suite.requires_prepack = True
+    test_suite.prepacked_args = ["weight", "bias", "mean", "var"]
 
     return test_suite
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index d1303d14ebb..7ccfa89e8e7 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -950,11 +950,10 @@ def test_vulkan_backend_native_layer_norm(self):
         class NativeLayerNormModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
+                self.layer_norm = torch.nn.LayerNorm(5)
 
             def forward(self, x):
-                return torch.native_layer_norm(
-                    x, [5], torch.ones(5), torch.zeros(5), 1e-5
-                )
+                return self.layer_norm(x)
 
         sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)
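A second illustrative fragment, with the same hypothetical-naming caveat as above: the direct-copy path defined in Staging.cpp skips the staging shader entirely and is meant for data whose bytes are already laid out exactly as the consuming shader expects, such as the packed 4-bit weights in add_q_4w_linear_node.

    // Hypothetical fragment (not from the patch). `mat2_data` is assumed to
    // be a TensorRef holding pre-packed quantized weights. The result is
    // always a kBuffer, kWidthPacked tensor filled by a verbatim
    // buffer_to_buffer copy, with no layout or dtype conversion applied.
    ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data);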