@@ -58,7 +58,7 @@ struct Value final {
5858 bool as_bool;
5959 } u;
6060
61- api::vTensor as_tensor;
61+ std::unique_ptr< api::vTensor> as_tensor;
6262 api::StagingBuffer as_staging;
6363 TensorRef as_tensorref;
6464
@@ -106,15 +106,18 @@ struct Value final {
106106 rhs.payload.member_name.~dtor_name (); \
107107 break ;
108108
109+ #define CASE_MOVE_UNIQUE_PTR_TYPE (type_tag, member_name ) \
110+ case type_tag: \
111+ payload.member_name = std::move(rhs.payload.member_name); \
112+ break ;
113+
109114 Value (Value&& rhs) noexcept : tag(rhs.tag) {
110115 switch (tag) {
111116 // Scalar types
112117 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::INT, as_int);
113118 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::DOUBLE, as_double);
114119 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE (TypeTag::BOOL, as_bool);
115- // Tensor and tensor adjacent types
116- CASE_MOVE_MOVEABLE_TYPE (
117- TypeTag::TENSOR, api::vTensor, as_tensor, vTensor);
120+ // Tensor adjacent types
118121 CASE_MOVE_MOVEABLE_TYPE (
119122 TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer);
120123 CASE_MOVE_MOVEABLE_TYPE (
@@ -132,6 +135,8 @@ struct Value final {
132135 CASE_MOVE_MOVEABLE_TYPE (
133136 TypeTag::STRING, std::string, as_string, basic_string);
134137 CASE_MOVE_MOVEABLE_TYPE (TypeTag::SYMINT, SymInt, as_symint, SymInt);
138+ // Tensor type
139+ CASE_MOVE_UNIQUE_PTR_TYPE (TypeTag::TENSOR, as_tensor);
135140
136141 case TypeTag::NONE:
137142 clearToNone ();
@@ -142,6 +147,7 @@ struct Value final {
142147
143148#undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE
144149#undef CASE_MOVE_MOVEABLE_TYPE
150+ #undef CASE_MOVE_UNIQUE_PTR_TYPE
145151
146152 //
147153 // Accessors
@@ -157,9 +163,6 @@ struct Value final {
157163
158164 ~Value () {
159165 switch (tag) {
160- case TypeTag::TENSOR:
161- payload.as_tensor .~vTensor ();
162- break ;
163166 case TypeTag::STAGING:
164167 payload.as_staging .~StagingBuffer ();
165168 break ;
@@ -184,6 +187,9 @@ struct Value final {
184187 case TypeTag::SYMINT:
185188 payload.as_symint .~SymInt ();
186189 break ;
190+ case TypeTag::TENSOR:
191+ payload.as_tensor .reset ();
192+ break ;
187193 // Manually list out the types so that if a type here is added later and
188194 // not handled the compiler can catch it.
189195 case TypeTag::NONE:
@@ -252,12 +258,6 @@ struct Value final {
252258 return payload.member_name ; \
253259 }
254260
255- SUPPORT_TRIVIALLY_MOVEABLE_TYPE (
256- api::vTensor,
257- Tensor,
258- TypeTag::TENSOR,
259- as_tensor);
260-
261261 SUPPORT_TRIVIALLY_MOVEABLE_TYPE (
262262 api::StagingBuffer,
263263 Staging,
@@ -302,9 +302,36 @@ struct Value final {
302302
303303 SUPPORT_TRIVIALLY_MOVEABLE_TYPE (SymInt, SymInt, TypeTag::SYMINT, as_symint);
304304
305- #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
306305#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
307306
307+ #define SUPPORT_UNIQUE_PTR_TYPE (type, type_name, type_tag, member_name ) \
308+ explicit Value (type t) : tag(type_tag) { \
309+ payload.member_name = std::make_unique<type>(std::move (t)); \
310+ } \
311+ inline bool is##type_name() const { \
312+ return tag == type_tag; \
313+ } \
314+ inline type& to##type_name() const { \
315+ VK_CHECK_COND ( \
316+ is##type_name (), \
317+ " Expected value to have type " #type_name " , got " , \
318+ tag, \
319+ " instead." ); \
320+ return *payload.member_name ; \
321+ } \
322+ inline const type& toConst##type_name() const { \
323+ VK_CHECK_COND ( \
324+ is##type_name (), \
325+ " Expected value to have type " #type_name " , got " , \
326+ tag, \
327+ " instead." ); \
328+ return *payload.member_name ; \
329+ }
330+
331+ SUPPORT_UNIQUE_PTR_TYPE (api::vTensor, Tensor, TypeTag::TENSOR, as_tensor);
332+
333+ #undef SUPPORT_UNIQUE_PTR_TYPE
334+
308335 private:
309336 Payload payload;
310337 TypeTag tag;
0 commit comments