@@ -38,6 +38,7 @@ limitations under the License.
38
38
/// "third_party/tensorflow/lite/c/common.h".
39
39
/// Only the TensorFlow Lite implementation itself should include this
40
40
/// file directly.
41
+ // IWYU pragma: private, include "third_party/tensorflow/lite/c/common.h"
41
42
42
43
#ifndef TENSORFLOW_LITE_CORE_C_COMMON_H_
43
44
#define TENSORFLOW_LITE_CORE_C_COMMON_H_
@@ -157,6 +158,10 @@ int TfLiteFloatArrayGetSizeInBytes(int size);
157
158
// This returns a pointer, that you must free using TfLiteFloatArrayFree().
158
159
TfLiteFloatArray * TfLiteFloatArrayCreate (int size );
159
160
161
+ // Create a copy of an array passed as `src`.
162
+ // You are expected to free memory with TfLiteFloatArrayFree.
163
+ TfLiteFloatArray * TfLiteFloatArrayCopy (const TfLiteFloatArray * src );
164
+
160
165
// Free memory of array `a`.
161
166
void TfLiteFloatArrayFree (TfLiteFloatArray * a );
162
167
#endif // TF_LITE_STATIC_MEMORY
@@ -345,6 +350,8 @@ typedef union TfLitePtrUnion {
345
350
// as constant inputs for downstream ops (also in prepare).
346
351
// * kTfLiteCustom: Custom memory allocation provided by the user. See
347
352
// TfLiteCustomAllocation below.
353
+ // * kTfLiteVariantObject: Allocation is an arbitrary type-erased C++ object.
354
+ // Allocation and deallocation are done through `new` and `delete`.
348
355
typedef enum TfLiteAllocationType {
349
356
kTfLiteMemNone = 0 ,
350
357
kTfLiteMmapRo ,
@@ -353,6 +360,7 @@ typedef enum TfLiteAllocationType {
353
360
kTfLiteDynamic ,
354
361
kTfLitePersistentRo ,
355
362
kTfLiteCustom ,
363
+ kTfLiteVariantObject ,
356
364
} TfLiteAllocationType ;
357
365
358
366
// The delegates should use zero or positive integers to represent handles.
@@ -959,12 +967,53 @@ typedef struct TfLiteRegistration {
959
967
// ops. We keep it inside of `TfLiteRegistration` and use it to route
960
968
// callbacks properly.
961
969
TfLiteRegistrationExternal * registration_external ;
970
+
971
+ // Retrieves asynchronous kernel.
972
+ //
973
+ // If the `async_kernel` field is nullptr, it means the operation described by
974
+ // this TfLiteRegistration object does not support asynchronous execution.
975
+ // Otherwise, the function that the field points to should only be called for
976
+ // delegate kernel nodes, i.e. `node` should be a delegate kernel node created
977
+ // by applying a delegate.
978
+ // If the function returns nullptr, that means that the underlying delegate
979
+ // does not support asynchronous execution for this `node`.
980
+ struct TfLiteAsyncKernel * (* async_kernel )(TfLiteContext * context ,
981
+ TfLiteNode * node );
962
982
} TfLiteRegistration ;
963
983
984
+ /// \private
964
985
// Old version of `TfLiteRegistration` to maintain binary backward
965
986
// compatibility.
966
- // WARNING: This structure is deprecated / not an official part of the API.
967
- // It should be only used for binary backward compatibility.
987
+ // The legacy registration type must be a POD struct type whose field types must
988
+ // be a prefix of the field types in TfLiteRegistration, and offset of the first
989
+ // field in TfLiteRegistration that is not present in the legacy registration
990
+ // type must be greater than or equal to the size of the legacy registration
991
+ // type.
992
+ // WARNING: This structure is deprecated / not an official part of the
993
+ // API. It should be only used for binary backward compatibility.
994
+ typedef struct TfLiteRegistration_V2 {
995
+ void * (* init )(TfLiteContext * context , const char * buffer , size_t length );
996
+ void (* free )(TfLiteContext * context , void * buffer );
997
+ TfLiteStatus (* prepare )(TfLiteContext * context , TfLiteNode * node );
998
+ TfLiteStatus (* invoke )(TfLiteContext * context , TfLiteNode * node );
999
+ const char * (* profiling_string )(const TfLiteContext * context ,
1000
+ const TfLiteNode * node );
1001
+ int32_t builtin_code ;
1002
+ const char * custom_name ;
1003
+ int version ;
1004
+ TfLiteRegistrationExternal * registration_external ;
1005
+ } TfLiteRegistration_V2 ;
1006
+
1007
+ /// \private
1008
+ // Old version of `TfLiteRegistration` to maintain binary backward
1009
+ // compatibility.
1010
+ // The legacy registration type must be a POD struct type whose field types must
1011
+ // be a prefix of the field types in TfLiteRegistration, and offset of the first
1012
+ // field in TfLiteRegistration that is not present in the legacy registration
1013
+ // type must be greater than or equal to the size of the legacy registration
1014
+ // type.
1015
+ // WARNING: This structure is deprecated / not an official part of the
1016
+ // API. It should be only used for binary backward compatibility.
968
1017
typedef struct TfLiteRegistration_V1 {
969
1018
void * (* init )(TfLiteContext * context , const char * buffer , size_t length );
970
1019
void (* free )(TfLiteContext * context , void * buffer );
@@ -1155,5 +1204,74 @@ void* TfLiteOpaqueDelegateGetData(const TfLiteOpaqueDelegate* delegate);
1155
1204
1156
1205
#ifdef __cplusplus
1157
1206
} // extern "C"
1207
+
1208
+ #include <utility>
1209
+
1210
// `kTfLiteVariant` type tensors encode arbitrary C++ objects behind their
// `data.data : void*` member. This is the type-erased interface for
// interacting with such objects at runtime. Deleting or Cloning any
// `VariantData` will call the destructor and copy constructor of the erased
// type automatically. For example usage, see `common_test.cc`.
class VariantData {
 public:
  // All variant objects must be able to be destroyed and copied.
  virtual ~VariantData() = default;
  // This allows for a "virtual copy-constructor" like pattern.
  // In most cases we are copying from an input to an output tensor, and the
  // output tensor's buffer is often already allocated; a pointer to it may be
  // passed via `maybe_alloc` so the storage can be reused.
  virtual VariantData* Clone(char* maybe_alloc) const = 0;
};

// An abstract base class for variant objects. The template parameter is the
// type we are erasing.
template <typename ErasedDerived>
class AbstractVariantData : public VariantData {
 public:
  VariantData* Clone(char* maybe_alloc) const override {
    const ErasedDerived& self = static_cast<const ErasedDerived&>(*this);
    if (maybe_alloc == nullptr) {
      // No buffer offered for reuse: allocate a fresh copy on the heap.
      return new ErasedDerived(self);
    }
    // We assume the output tensor already holds a variant of the same derived
    // type. If it is still allocated it may carry state that was not yet
    // destroyed, so the destructor must run before the buffer is reused.
    // NOTE: this may actually have a non-negligible effect on performance if
    // the destructor is complex. A future optimization could introduce
    // "move to" semantics so implementations can optimize this case.
    reinterpret_cast<VariantData*>(maybe_alloc)->~VariantData();
    return new (maybe_alloc) ErasedDerived(self);
  }

 protected:
  AbstractVariantData() = default;
  AbstractVariantData(const AbstractVariantData&) = default;
  AbstractVariantData(AbstractVariantData&&) = delete;
};
1253
+
1254
+ // Analogous to `TfLiteTensorRealloc` for allocation of tensors whose
1255
+ // data member points to an arbitrary C++ object. `VariantType` refers
1256
+ // to the erased type of said object and `VariantArgs` refers to
1257
+ // a list of argument types with which to construct a new `VariantType`
1258
+ // `VariantArgs` must match constructor in `VariantType`.
1259
+ template < class VariantType , class ... VariantArgs >
1260
+ TfLiteStatus TfLiteTensorVariantRealloc (TfLiteTensor * t ,
1261
+ VariantArgs && ... args ) {
1262
+ if (t -> type != kTfLiteVariant ) return kTfLiteError ;
1263
+ if (t -> data .raw ) {
1264
+ reinterpret_cast < VariantData * > (t -> data .data )-> ~VariantData ();
1265
+ // For now we assume if `t` is already allocated then it was allocated
1266
+ // with the same `VariantType` as templated.
1267
+ t -> data .data =
1268
+ new (t -> data .raw ) VariantType (std ::forward < VariantArgs ...> (args ...));
1269
+ } else {
1270
+ t -> data .data = new VariantType (std ::forward < VariantArgs ...> (args ...));
1271
+ }
1272
+ t -> allocation_type = kTfLiteVariantObject ;
1273
+ return kTfLiteOk ;
1274
+ }
1275
+
1158
1276
#endif // __cplusplus
1159
1277
#endif // TENSORFLOW_LITE_CORE_C_COMMON_H_
0 commit comments