diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78135a434e5..5c19a6003e8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -424,6 +424,12 @@ class ComputeGraph final { // Scalar Value Extraction // + bool is_scalar_or_none(const ValueRef idx) const { + const Value& value = values_.at(idx); + return value.isInt() || value.isDouble() || value.isBool() || + value.isNone(); + } + template T extract_scalar(const ValueRef idx) { Value& value = values_.at(idx); @@ -439,6 +445,15 @@ class ComputeGraph final { VK_THROW("Cannot extract scalar from Value with type ", value.type()); } + template + T extract_scalar_or(const ValueRef idx, const T default_value) { + Value& value = values_.at(idx); + if (value.isNone()) { + return default_value; + } + return extract_scalar(idx); + } + template std::optional extract_optional_scalar(const ValueRef idx) { if (val_is_none(idx)) {