diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 0486110ced6..754696f7ac5 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -259,6 +259,17 @@ def register_binary_op(features: OpFeatures): valid_packed_dims=all_packed_dims, ) features.resize_fn = True + features.buffer_impl = True + return features + + +@update_features(exir_ops.edge.aten.where.self) +def register_where_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims=all_packed_dims, + ) + features.buffer_impl = True + features.resize_fn = True return features