@@ -63,6 +63,7 @@ DECL_VALUE_PTR_CLASS(IntListPtr, std::vector<int64_t>)
63
63
DECL_VALUE_PTR_CLASS (DoubleListPtr, std::vector<double >)
64
64
DECL_VALUE_PTR_CLASS (BoolListPtr, std::vector<bool >)
65
65
DECL_VALUE_PTR_CLASS (ValueListPtr, std::vector<ValueRef>)
66
+ DECL_VALUE_PTR_CLASS (SymIntPtr, SymInt);
66
67
67
68
#undef DECL_VALUE_PTR_CLASS
68
69
@@ -154,6 +155,7 @@ class ComputeGraph final {
154
155
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS (DoubleListPtr, double_list, DoubleList)
155
156
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS (BoolListPtr, bool_list, BoolList)
156
157
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS (ValueListPtr, value_list, ValueList)
158
+ GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS (SymIntPtr, symint, SymInt);
157
159
158
160
#undef GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS
159
161
@@ -422,15 +424,28 @@ class ComputeGraph final {
422
424
423
425
ValueRef add_string (std::string&& str);
424
426
427
+ ValueRef add_symint (const int32_t val);
428
+
425
429
ValueRef set_input_tensor (const ValueRef idx, const bool use_staging = true );
426
430
ValueRef set_output_tensor (const ValueRef idx, const bool use_staging = true );
427
431
428
432
template <typename Block>
429
- const vkapi::BufferBindInfo create_params_buffer (const Block& data) {
433
+ vkapi::BufferBindInfo create_params_buffer (const Block& data) {
430
434
param_ubos_.emplace_back (api::ParamsBuffer (context_.get (), data));
431
435
return vkapi::BufferBindInfo (param_ubos_.back ().buffer ());
432
436
}
433
437
438
+ /*
439
+ * Given a ValueRef, do the following depending on the type of the Value:
440
+ * - If it is a SymInt, return the BufferBindInfo of the ParamsBuffer object
441
+ * backing the SymInt.
442
+ * - If it is a regular Int, create a new ParamsBuffer using the integer value
443
+ * and return the BufferBindInfo of the created ParamsBuffer.
444
+ */
445
+ vkapi::BufferBindInfo get_or_create_int_param_buffer (const ValueRef idx);
446
+
447
+ void set_symint (const ValueRef idx, const int32_t val);
448
+
434
449
/*
435
450
* Convenience function to add an input tensor along with its staging buffer
436
451
*/
@@ -577,6 +592,7 @@ class ComputeGraph final {
577
592
friend class DoubleListPtr ;
578
593
friend class BoolListPtr ;
579
594
friend class ValueListPtr ;
595
+ friend class SymIntPtr ;
580
596
};
581
597
582
598
template <typename T>
0 commit comments