Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def register_transfer_ops(features: OpFeatures):
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ones.default,
exir_ops.edge.aten.ones_like.default,
exir_ops.edge.aten.scalar_tensor.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.zeros.default,
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
return val.toConstTensor().dtype();
} else if (val.isTensorRef()) {
return val.toConstTensorRef().dtype;
} else if (val.isBool()) {
return vkapi::ScalarType::Bool;
} else if (val.isDouble()) {
// We downcast anyway in the shader and we want to avoid having to
// write special cases there.
return vkapi::ScalarType::Float;
} else if (val.isInt()) {
return vkapi::ScalarType::Int;
}
VK_THROW("Could not get dtype of value with type ", val.type());
}
Expand Down
55 changes: 55 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/scalar_tensor.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define BUF_T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${texel_type(DTYPE)}

${define_active_storage_type(STORAGE)}
${define_required_extensions(DTYPE)}
${define_required_extensions(SCALAR_VALUE_TYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_ubo(B, buffer_scalar_type(SCALAR_VALUE_TYPE), "scalar_value")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#ifdef USING_BUFFER

void main() {
const int i = int(gl_GlobalInvocationID.x);

if (i > 0) {
return;
}

t_out[i] = BUF_T(scalar_value);
}

# else // !USING_BUFFER

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

// Scalar tensor is a special case where the packed dim is always 1.
if (any(greaterThanEqual(pos, ivec3(1)))) {
return;
}

VEC4_T outtex = VEC4_T(scalar_value);
write_texel(t_out, pos, outtex);
}

#endif // !USING_BUFFER
27 changes: 27 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/scalar_tensor.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

scalar_tensor:
parameter_names_with_default_values:
NDIM: 3
DTYPE: float
SCALAR_VALUE_TYPE: float
PACKING: C_packed
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
- VALUE: int32
STORAGE:
- VALUE: texture3d
- VALUE: buffer
SCALAR_VALUE_TYPE:
- VALUE: float
- VALUE: int32
- VALUE: bool
shader_variants:
- NAME: scalar_tensor
54 changes: 54 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/ScalarTensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void scalar_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// Extract the scalar value from the first argument
ValueRef scalar_in = args[0];
float scalar_value = graph.extract_scalar<float>(scalar_in);

// Get the output tensor reference
ValueRef out = args[args.size() - 1];

std::string kernel_name("scalar_tensor");
kernel_name.reserve(kShaderNameReserve);

add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(scalar_in));

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
// Inputs and Outputs
{{out, vkapi::kWrite}},
// Shader params buffers
{graph.create_params_buffer(scalar_value)},
// Push Constants
{},
// Specialization Constants
{},
// Resize Args
{},
// Resizing Logic
nullptr));
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.scalar_tensor.default, scalar_tensor);
}

} // namespace vkcompute
15 changes: 15 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,21 @@ def get_full_inputs():
return test_suite


@register_test_suite("aten.scalar_tensor.default")
def get_scalar_tensor_inputs():
test_suite = VkTestSuite(
[
(42.0,),
(3.14,),
(2.72,),
(0.0,),
(-1.0,),
(100.0,),
]
)
return test_suite


@register_test_suite(
[
"aten.zeros.default",
Expand Down
Loading