Skip to content

Commit 775ab35

Browse files
committed
Update on "[ET-VK] Adding get or create int function to read int value."
This diff adds a new function `get_or_create_int` to the `ComputeGraph` class, which allows reading an integer value from a `ValueRef` index. The function returns the extracted integer value if the value at the index is an integer, otherwise it throws an error. Additionally, an overload of the function is added to return a default value if the value at the index is `None`. Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/) [ghstack-poisoned]
2 parents 48ea65b + 344febd commit 775ab35

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,15 +549,6 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
549549
}
550550
}
551551

552-
int32_t ComputeGraph::get_or_create_int(
553-
const ValueRef idx,
554-
const int32_t default_val) {
555-
if (values_.at(idx).isNone()) {
556-
return default_val;
557-
}
558-
return extract_scalar<int32_t>(idx);
559-
}
560-
561552
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
562553
get_symint(idx)->set(val);
563554
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class ComputeGraph final {
424424
// Scalar Value Extraction
425425
//
426426

427-
bool is_scalar(const ValueRef idx) const {
427+
bool is_scalar_or_none(const ValueRef idx) const {
428428
const Value& value = values_.at(idx);
429429
return value.isInt() || value.isDouble() || value.isBool() ||
430430
value.isNone();
@@ -445,6 +445,15 @@ class ComputeGraph final {
445445
VK_THROW("Cannot extract scalar from Value with type ", value.type());
446446
}
447447

448+
template <typename T>
449+
T extract_scalar_or(const ValueRef idx, const T default_value) {
450+
Value& value = values_.at(idx);
451+
if (value.isNone()) {
452+
return default_value;
453+
}
454+
return extract_scalar<T>(idx);
455+
}
456+
448457
template <typename T>
449458
std::optional<T> extract_optional_scalar(const ValueRef idx) {
450459
if (val_is_none(idx)) {
@@ -685,8 +694,6 @@ class ComputeGraph final {
685694
const ValueRef idx,
686695
const int32_t default_value);
687696

688-
int32_t get_or_create_int(const ValueRef idx, const int32_t default_value);
689-
690697
void set_symint(const ValueRef idx, const int32_t val);
691698

692699
int32_t read_symint(const ValueRef idx);

0 commit comments

Comments
 (0)