|  | 
| 15 | 15 | #include <cstring> | 
| 16 | 16 | #include <memory> | 
| 17 | 17 | #include <string> | 
|  | 18 | +#include <variant> | 
| 18 | 19 | 
 | 
| 19 | 20 | // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. | 
| 20 | 21 | #include <executorch/extension/pytree/function_ref.h> | 
| @@ -55,51 +56,36 @@ using KeyInt = int32_t; | 
| 55 | 56 | struct Key { | 
| 56 | 57 |   enum class Kind : uint8_t { None, Int, Str } kind_; | 
| 57 | 58 | 
 | 
| 58 |  | -  KeyInt as_int_ = {}; | 
| 59 |  | -  KeyStr as_str_ = {}; | 
|  | 59 | + private: | 
|  | 60 | +  std::variant<std::monostate, KeyInt, KeyStr> repr_; | 
| 60 | 61 | 
 | 
| 61 |  | -  Key() : kind_(Kind::None) {} | 
| 62 |  | -  /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {} | 
| 63 |  | -  /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {} | 
|  | 62 | + public: | 
|  | 63 | +  Key() {} | 
|  | 64 | +  /*implicit*/ Key(KeyInt key) : repr_(key) {} | 
|  | 65 | +  /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} | 
| 64 | 66 | 
 | 
| 65 |  | -  const Kind& kind() const { | 
| 66 |  | -    return kind_; | 
|  | 67 | +  Kind kind() const { | 
|  | 68 | +    return static_cast<Kind>(repr_.index()); | 
| 67 | 69 |   } | 
| 68 | 70 | 
 | 
| 69 |  | -  const KeyInt& as_int() const { | 
| 70 |  | -    pytree_assert(kind_ == Key::Kind::Int); | 
| 71 |  | -    return as_int_; | 
|  | 71 | +  KeyInt as_int() const { | 
|  | 72 | +    return std::get<KeyInt>(repr_); | 
| 72 | 73 |   } | 
| 73 | 74 | 
 | 
| 74 |  | -  operator const KeyInt&() const { | 
|  | 75 | +  operator KeyInt() const { | 
| 75 | 76 |     return as_int(); | 
| 76 | 77 |   } | 
| 77 | 78 | 
 | 
| 78 | 79 |   const KeyStr& as_str() const { | 
| 79 |  | -    pytree_assert(kind_ == Key::Kind::Str); | 
| 80 |  | -    return as_str_; | 
|  | 80 | +    return std::get<KeyStr>(repr_); | 
| 81 | 81 |   } | 
| 82 | 82 | 
 | 
| 83 | 83 |   operator const KeyStr&() const { | 
| 84 | 84 |     return as_str(); | 
| 85 | 85 |   } | 
| 86 | 86 | 
 | 
| 87 | 87 |   bool operator==(const Key& rhs) const { | 
| 88 |  | -    if (kind_ != rhs.kind_) { | 
| 89 |  | -      return false; | 
| 90 |  | -    } | 
| 91 |  | -    switch (kind_) { | 
| 92 |  | -      case Kind::Str: { | 
| 93 |  | -        return as_str_ == rhs.as_str_; | 
| 94 |  | -      } | 
| 95 |  | -      case Kind::Int: { | 
| 96 |  | -        return as_int_ == rhs.as_int_; | 
| 97 |  | -      } | 
| 98 |  | -      case Kind::None: { | 
| 99 |  | -        return true; | 
| 100 |  | -      } | 
| 101 |  | -    } | 
| 102 |  | -    pytree_unreachable(); | 
|  | 88 | +    return repr_ == rhs.repr_; | 
| 103 | 89 |   } | 
| 104 | 90 | 
 | 
| 105 | 91 |   bool operator!=(const Key& rhs) const { | 
|  | 
0 commit comments