Skip to content

Commit f264487

Browse files
committed
Update on "[ET-VK] Adding push constant and ubo verison of select and slice ops to improve memory and performance."
Adding push constant and ubo verison of select and slice ops to improve memory and performance. * Updated `transfer_buffer.yaml` and `transfer_texture.yaml` to include `UBO_PARAMS` parameter and generate variants for `select` and `slice` ops with UBO parameters. * Updated `transfer.glsl` to generate ubo and push constant versions of `select` and `slice` ops with UBO parameters. Differential Revision: [D78095262](https://our.internmc.facebook.com/intern/diff/D78095262/) [ghstack-poisoned]
2 parents 7f8765e + 29d9107 commit f264487

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,15 +549,6 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
549549
}
550550
}
551551

552-
int32_t ComputeGraph::get_or_create_int(
553-
const ValueRef idx,
554-
const int32_t default_val) {
555-
if (values_.at(idx).isNone()) {
556-
return default_val;
557-
}
558-
return extract_scalar<int32_t>(idx);
559-
}
560-
561552
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
562553
get_symint(idx)->set(val);
563554
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class ComputeGraph final {
424424
// Scalar Value Extraction
425425
//
426426

427-
bool is_scalar(const ValueRef idx) const {
427+
bool is_scalar_or_none(const ValueRef idx) const {
428428
const Value& value = values_.at(idx);
429429
return value.isInt() || value.isDouble() || value.isBool() ||
430430
value.isNone();
@@ -445,6 +445,15 @@ class ComputeGraph final {
445445
VK_THROW("Cannot extract scalar from Value with type ", value.type());
446446
}
447447

448+
template <typename T>
449+
T extract_scalar_or(const ValueRef idx, const T default_value) {
450+
Value& value = values_.at(idx);
451+
if (value.isNone()) {
452+
return default_value;
453+
}
454+
return extract_scalar<T>(idx);
455+
}
456+
448457
template <typename T>
449458
std::optional<T> extract_optional_scalar(const ValueRef idx) {
450459
if (val_is_none(idx)) {
@@ -685,8 +694,6 @@ class ComputeGraph final {
685694
const ValueRef idx,
686695
const int32_t default_value);
687696

688-
int32_t get_or_create_int(const ValueRef idx, const int32_t default_value);
689-
690697
void set_symint(const ValueRef idx, const int32_t val);
691698

692699
int32_t read_symint(const ValueRef idx);

backends/vulkan/runtime/graph/ops/impl/Transfer.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ void add_transfer_copy_node(
4646
int32_t step_ref;
4747
} transfer_params{static_cast<int32_t>(dim_whcn), 0, 0};
4848

49-
const bool param_is_scalar = graph.is_scalar(index_or_start_ref) &&
50-
(transfer_type == TransferType::SELECT || graph.is_scalar(step_ref));
49+
const bool param_is_scalar = graph.is_scalar_or_none(index_or_start_ref) &&
50+
(transfer_type == TransferType::SELECT ||
51+
graph.is_scalar_or_none(step_ref));
5152

5253
vkapi::ParamsBindList param_buffers;
5354
if (!param_is_scalar) {
@@ -61,9 +62,9 @@ void add_transfer_copy_node(
6162
}
6263
} else {
6364
transfer_params.index_or_start_ref =
64-
graph.get_or_create_int(index_or_start_ref, 0);
65+
graph.extract_scalar_or<int32_t>(index_or_start_ref, 0);
6566
if (transfer_type != TransferType::SELECT) {
66-
transfer_params.step_ref = graph.get_or_create_int(step_ref, 1);
67+
transfer_params.step_ref = graph.extract_scalar_or<int32_t>(step_ref, 1);
6768
}
6869
}
6970

0 commit comments

Comments
 (0)