Skip to content

Commit 35d0f59

Browse files
authored
[ET-VK] Add type for symbolic integers
Differential Revision: D62144399 Pull Request resolved: #5040
1 parent f65531b commit 35d0f59

File tree

9 files changed

+204
-1
lines changed

9 files changed

+204
-1
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
4343
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
4444
VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
4545
VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
46+
VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)
4647

4748
#undef VALUE_PTR_CLASS_IMPL
4849

@@ -261,6 +262,13 @@ ValueRef ComputeGraph::add_string(std::string&& str) {
261262
return idx;
262263
}
263264

265+
ValueRef ComputeGraph::add_symint(const int32_t val) {
266+
ValueRef idx(static_cast<int>(values_.size()));
267+
check_no_active_value_ptrs();
268+
values_.emplace_back(SymInt(context(), val));
269+
return idx;
270+
}
271+
264272
ValueRef ComputeGraph::set_input_tensor(
265273
const ValueRef idx,
266274
const bool use_staging) {
@@ -300,6 +308,22 @@ ValueRef ComputeGraph::set_output_tensor(
300308
return idx;
301309
}
302310

311+
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
312+
const ValueRef idx) {
313+
if (values_.at(idx).isInt()) {
314+
const int32_t val = extract_scalar<int32_t>(idx);
315+
create_params_buffer(val);
316+
} else if (values_.at(idx).isSymInt()) {
317+
SymIntPtr symint = get_symint(idx);
318+
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
319+
}
320+
VK_THROW("Cannot create a int param buffer for the given value");
321+
}
322+
323+
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
324+
get_symint(idx)->set(val);
325+
}
326+
303327
SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
304328
if (idx >= shared_objects_.size()) {
305329
shared_objects_.resize(static_cast<size_t>(idx + 1));

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ DECL_VALUE_PTR_CLASS(IntListPtr, std::vector<int64_t>)
6363
DECL_VALUE_PTR_CLASS(DoubleListPtr, std::vector<double>)
6464
DECL_VALUE_PTR_CLASS(BoolListPtr, std::vector<bool>)
6565
DECL_VALUE_PTR_CLASS(ValueListPtr, std::vector<ValueRef>)
66+
DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);
6667

6768
#undef DECL_VALUE_PTR_CLASS
6869

@@ -154,6 +155,7 @@ class ComputeGraph final {
154155
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(DoubleListPtr, double_list, DoubleList)
155156
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(BoolListPtr, bool_list, BoolList)
156157
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(ValueListPtr, value_list, ValueList)
158+
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(SymIntPtr, symint, SymInt);
157159

158160
#undef GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS
159161

@@ -422,15 +424,28 @@ class ComputeGraph final {
422424

423425
ValueRef add_string(std::string&& str);
424426

427+
ValueRef add_symint(const int32_t val);
428+
425429
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
426430
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
427431

428432
template <typename Block>
429-
const vkapi::BufferBindInfo create_params_buffer(const Block& data) {
433+
vkapi::BufferBindInfo create_params_buffer(const Block& data) {
430434
param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
431435
return vkapi::BufferBindInfo(param_ubos_.back().buffer());
432436
}
433437

438+
/*
439+
* Given a ValueRef, do the following depending on the type of the Value:
440+
* - If it is a SymInt, return the BufferBindInfo of the ParamsBuffer object
441+
* backing the SymInt.
442+
* - If it is a regular Int, create a new ParamsBuffer using the integer value
443+
* and return the BufferBindInfo of the created ParamsBuffer.
444+
*/
445+
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
446+
447+
void set_symint(const ValueRef idx, const int32_t val);
448+
434449
/*
435450
* Convenience function to add an input tensor along with its staging buffer
436451
*/
@@ -577,6 +592,7 @@ class ComputeGraph final {
577592
friend class DoubleListPtr;
578593
friend class BoolListPtr;
579594
friend class ValueListPtr;
595+
friend class SymIntPtr;
580596
};
581597

582598
template <typename T>
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
10+
11+
namespace vkcompute {
12+
13+
SymInt::SymInt(api::Context* context_p, const int32_t val)
14+
: gpu_buffer(context_p, val){};
15+
16+
void SymInt::set(const int32_t val) {
17+
gpu_buffer.update(val);
18+
}
19+
20+
void SymInt::operator=(const int32_t val) {
21+
gpu_buffer.update(val);
22+
}
23+
24+
} // namespace vkcompute
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/Context.h>
12+
#include <executorch/backends/vulkan/runtime/api/containers/ParamsBuffer.h>
13+
14+
namespace vkcompute {
15+
16+
/*
17+
* Represents a symbolic integer whose value can be variable. It is implemented
18+
* as a thin wrapper around a `ParamsBuffer` object that holds the value of the
19+
* integer. The `ParamsBuffer` object allows the value of the symbolic integer
20+
* to be changed from the CPU and have those changes be visible to all shaders
21+
* that use the symbolic integer; it also allows the value of the symbolic
22+
* integer to be the result of a compute shader.
23+
*
24+
* Regular scalar types represented by `TypeTag::INT` cannot be used for
25+
* symbolic integers because their value is assumed to be constant; therefore
26+
* the `Value` instance holding the value of the scalar does not contain
27+
* any reference to the GPU buffers used to pass its value into compute shaders.
28+
* Therefore, updating the value of the scalar does not impact the value seen
29+
* by compute shaders.
30+
*/
31+
struct SymInt final {
32+
api::ParamsBuffer gpu_buffer;
33+
34+
explicit SymInt(api::Context* context_p, const int32_t val);
35+
36+
void set(const int32_t val);
37+
38+
void operator=(const int32_t val);
39+
};
40+
41+
} // namespace vkcompute

backends/vulkan/runtime/graph/containers/Types.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ std::ostream& operator<<(std::ostream& out, const TypeTag& tag) {
2929
PRINT_CASE(BOOLLIST)
3030
PRINT_CASE(VALUELIST)
3131
PRINT_CASE(STRING)
32+
PRINT_CASE(SYMINT)
3233
}
3334
return out;
3435
}

backends/vulkan/runtime/graph/containers/Types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum class TypeTag : uint32_t {
3636
// Special Type
3737
VALUELIST,
3838
STRING,
39+
SYMINT,
3940
};
4041

4142
std::ostream& operator<<(std::ostream& out, const TypeTag& tag);

backends/vulkan/runtime/graph/containers/Value.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/api/api.h>
1414

1515
#include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
16+
#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
1617
#include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
1718

1819
namespace vkcompute {
@@ -67,6 +68,8 @@ struct Value final {
6768

6869
std::string as_string;
6970

71+
SymInt as_symint;
72+
7073
Payload() : u() {}
7174
// NOLINTNEXTLINE
7275
~Payload(){};
@@ -123,6 +126,7 @@ struct Value final {
123126
TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
124127
CASE_MOVE_MOVEABLE_TYPE(
125128
TypeTag::STRING, std::string, as_string, basic_string);
129+
CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
126130

127131
case TypeTag::NONE:
128132
clearToNone();
@@ -172,6 +176,9 @@ struct Value final {
172176
case TypeTag::STRING:
173177
payload.as_string.~basic_string();
174178
break;
179+
case TypeTag::SYMINT:
180+
payload.as_symint.~SymInt();
181+
break;
175182
// Manually list out the types so that if a type here is added later and
176183
// not handled the compiler can catch it.
177184
case TypeTag::NONE:
@@ -288,6 +295,8 @@ struct Value final {
288295
TypeTag::STRING,
289296
as_string);
290297

298+
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
299+
291300
#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
292301
#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
293302

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
${layout_declare_tensor(0, "rw", "t_in", "float", "texture3d")}
16+
${layout_declare_ubo(1, "uvec3", "extents")}
17+
${layout_declare_ubo(2, "int", "scalar")}
18+
19+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
20+
21+
void main() {
22+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
23+
if (any(greaterThanEqual(pos, extents))) {
24+
return;
25+
}
26+
27+
vec4 in_tex = imageLoad(t_in, pos);
28+
imageStore(t_in, pos, imageLoad(t_in, pos) + float(scalar));
29+
}

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,64 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
12681268
}
12691269
}
12701270

1271+
TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) {
1272+
GraphConfig config;
1273+
config.set_storage_type_override(utils::kTexture3D);
1274+
ComputeGraph graph(config);
1275+
1276+
std::vector<int64_t> sizes = {8, 64, 124};
1277+
1278+
// Build graph
1279+
1280+
ValueRef scalar = graph.add_symint(1);
1281+
IOValueRef a = graph.add_input_tensor(sizes, vkapi::kFloat);
1282+
1283+
IOValueRef out = {};
1284+
out.value = a.value;
1285+
1286+
graph.execute_nodes().emplace_back(new ExecuteNode(
1287+
graph,
1288+
VK_KERNEL_FROM_STR("scalar_add_texture"),
1289+
graph.create_global_wg_size(a.value),
1290+
graph.create_local_wg_size(a.value),
1291+
// Inputs and Outputs
1292+
{{out.value, vkapi::MemoryAccessType::WRITE}},
1293+
// Shader params buffers
1294+
{graph.texture_limits_ubo(a.value),
1295+
graph.get_or_create_int_param_buffer(scalar)},
1296+
// Specialization Constants
1297+
{},
1298+
// Resizing Logic
1299+
nullptr,
1300+
{}));
1301+
1302+
out.staging = graph.set_output_tensor(out.value);
1303+
1304+
graph.prepare();
1305+
graph.encode_execute();
1306+
1307+
// Run graph
1308+
1309+
for (float i = 5.0f; i < 30.0f; i += 10.0f) {
1310+
int scalar_val = i - 3.0f;
1311+
graph.set_symint(scalar, scalar_val);
1312+
1313+
float val_a = i + 2.0f;
1314+
float val_out = val_a + scalar_val;
1315+
1316+
fill_vtensor(graph, a, val_a);
1317+
1318+
graph.execute();
1319+
1320+
EXTRACT_TENSOR(out);
1321+
1322+
// Sanity check that the values are correct
1323+
for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
1324+
CHECK_VALUE(data_out, i, val_out);
1325+
}
1326+
}
1327+
}
1328+
12711329
#define CREATE_WEIGHT_TENSOR(name, sizes, dtype, val) \
12721330
std::vector<float> data_##name(utils::multiply_integers(sizes)); \
12731331
std::fill(data_##name.begin(), data_##name.end(), val); \

0 commit comments

Comments
 (0)