Skip to content

Commit 29d9107

Browse files
committed
Update base for Update on "[ET-VK] Adding push constant and ubo verison of select and slice ops to improve memory and performance."
Adding push constant and ubo verison of select and slice ops to improve memory and performance. * Updated `transfer_buffer.yaml` and `transfer_texture.yaml` to include `UBO_PARAMS` parameter and generate variants for `select` and `slice` ops with UBO parameters. * Updated `transfer.glsl` to generate ubo and push constant versions of `select` and `slice` ops with UBO parameters. Differential Revision: [D78095262](https://our.internmc.facebook.com/intern/diff/D78095262/) [ghstack-poisoned]
1 parent 85f2193 commit 29d9107

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,15 +549,6 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
549549
}
550550
}
551551

552-
int32_t ComputeGraph::get_or_create_int(
553-
const ValueRef idx,
554-
const int32_t default_val) {
555-
if (values_.at(idx).isNone()) {
556-
return default_val;
557-
}
558-
return extract_scalar<int32_t>(idx);
559-
}
560-
561552
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
562553
get_symint(idx)->set(val);
563554
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class ComputeGraph final {
424424
// Scalar Value Extraction
425425
//
426426

427-
bool is_scalar(const ValueRef idx) const {
427+
bool is_scalar_or_none(const ValueRef idx) const {
428428
const Value& value = values_.at(idx);
429429
return value.isInt() || value.isDouble() || value.isBool() ||
430430
value.isNone();
@@ -445,6 +445,15 @@ class ComputeGraph final {
445445
VK_THROW("Cannot extract scalar from Value with type ", value.type());
446446
}
447447

448+
template <typename T>
449+
T extract_scalar_or(const ValueRef idx, const T default_value) {
450+
Value& value = values_.at(idx);
451+
if (value.isNone()) {
452+
return default_value;
453+
}
454+
return extract_scalar<T>(idx);
455+
}
456+
448457
template <typename T>
449458
std::optional<T> extract_optional_scalar(const ValueRef idx) {
450459
if (val_is_none(idx)) {
@@ -685,8 +694,6 @@ class ComputeGraph final {
685694
const ValueRef idx,
686695
const int32_t default_value);
687696

688-
int32_t get_or_create_int(const ValueRef idx, const int32_t default_value);
689-
690697
void set_symint(const ValueRef idx, const int32_t val);
691698

692699
int32_t read_symint(const ValueRef idx);

0 commit comments

Comments
 (0)