From 8392ed57f87911bf94c3a15265cc386d7e9f3aba Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Tue, 15 Jul 2025 08:34:42 -0700 Subject: [PATCH] [ET-VK] Adding extract_scalar_or function to extract scalar value or return a default if value at index is none. Pull Request resolved: https://github.com/pytorch/executorch/pull/12357 This diff adds a new function `extract_scalar_or` to the `ComputeGraph` class, which extracts a scalar value from a `ValueRef` index. If the value at the index is `None`, it returns a default value. ghstack-source-id: 296319453 @exported-using-ghexport Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/) --- backends/vulkan/runtime/graph/ComputeGraph.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78135a434e5..5c19a6003e8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -424,6 +424,12 @@ class ComputeGraph final { // Scalar Value Extraction // + bool is_scalar_or_none(const ValueRef idx) const { + const Value& value = values_.at(idx); + return value.isInt() || value.isDouble() || value.isBool() || + value.isNone(); + } + template T extract_scalar(const ValueRef idx) { Value& value = values_.at(idx); @@ -439,6 +445,15 @@ class ComputeGraph final { VK_THROW("Cannot extract scalar from Value with type ", value.type()); } + template + T extract_scalar_or(const ValueRef idx, const T default_value) { + Value& value = values_.at(idx); + if (value.isNone()) { + return default_value; + } + return extract_scalar(idx); + } + template std::optional extract_optional_scalar(const ValueRef idx) { if (val_is_none(idx)) {