@@ -38,6 +38,7 @@ limitations under the License.
3838/// "third_party/tensorflow/lite/c/common.h".
3939/// Only the TensorFlow Lite implementation itself should include this
4040/// file directly.
41+ // IWYU pragma: private, include "third_party/tensorflow/lite/c/common.h"
4142
4243#ifndef TENSORFLOW_LITE_CORE_C_COMMON_H_
4344#define TENSORFLOW_LITE_CORE_C_COMMON_H_
@@ -157,6 +158,10 @@ int TfLiteFloatArrayGetSizeInBytes(int size);
157158// This returns a pointer, that you must free using TfLiteFloatArrayFree().
158159TfLiteFloatArray * TfLiteFloatArrayCreate (int size );
159160
161+ // Create a copy of an array passed as `src`.
162+ // You are expected to free memory with TfLiteFloatArrayFree.
163+ TfLiteFloatArray * TfLiteFloatArrayCopy (const TfLiteFloatArray * src );
164+
160165// Free memory of array `a`.
161166void TfLiteFloatArrayFree (TfLiteFloatArray * a );
162167#endif // TF_LITE_STATIC_MEMORY
@@ -345,6 +350,8 @@ typedef union TfLitePtrUnion {
345350// as constant inputs for downstream ops (also in prepare).
346351// * kTfLiteCustom: Custom memory allocation provided by the user. See
347352// TfLiteCustomAllocation below.
353+ // * kTfLiteVariantObject: Allocation is an arbitrary type-erased C++ object.
354+ // Allocation and deallocation are done through `new` and `delete`.
348355typedef enum TfLiteAllocationType {
349356 kTfLiteMemNone = 0 ,
350357 kTfLiteMmapRo ,
@@ -353,6 +360,7 @@ typedef enum TfLiteAllocationType {
353360 kTfLiteDynamic ,
354361 kTfLitePersistentRo ,
355362 kTfLiteCustom ,
363+ kTfLiteVariantObject ,
356364} TfLiteAllocationType ;
357365
358366// The delegates should use zero or positive integers to represent handles.
@@ -959,12 +967,53 @@ typedef struct TfLiteRegistration {
959967 // ops. We keep it inside of `TfLiteRegistration` and use it to route
960968 // callbacks properly.
961969 TfLiteRegistrationExternal * registration_external ;
970+
971+ // Retrieves asynchronous kernel.
972+ //
973+ // If the `async_kernel` field is nullptr, it means the operation described by
974+ // this TfLiteRegistration object does not support asynchronous execution.
975+ // Otherwise, the function that the field points to should only be called for
976+ // delegate kernel nodes, i.e. `node` should be a delegate kernel node created
977+ // by applying a delegate.
978+ // If the function returns nullptr, that means that the underlying delegate
979+ // does not support asynchronous execution for this `node`.
980+ struct TfLiteAsyncKernel * (* async_kernel )(TfLiteContext * context ,
981+ TfLiteNode * node );
962982} TfLiteRegistration ;
963983
984+ /// \private
964985// Old version of `TfLiteRegistration` to maintain binary backward
965986// compatibility.
966- // WARNING: This structure is deprecated / not an official part of the API.
967- // It should be only used for binary backward compatibility.
987+ // The legacy registration type must be a POD struct type whose field types must
988+ // be a prefix of the field types in TfLiteRegistration, and offset of the first
989+ // field in TfLiteRegistration that is not present in the legacy registration
990+ // type must be greater than or equal to the size of the legacy registration
991+ // type.
992+ // WARNING: This structure is deprecated / not an official part of the
993+ // API. It should be only used for binary backward compatibility.
994+ typedef struct TfLiteRegistration_V2 {
995+ void * (* init )(TfLiteContext * context , const char * buffer , size_t length );
996+ void (* free )(TfLiteContext * context , void * buffer );
997+ TfLiteStatus (* prepare )(TfLiteContext * context , TfLiteNode * node );
998+ TfLiteStatus (* invoke )(TfLiteContext * context , TfLiteNode * node );
999+ const char * (* profiling_string )(const TfLiteContext * context ,
1000+ const TfLiteNode * node );
1001+ int32_t builtin_code ;
1002+ const char * custom_name ;
1003+ int version ;
1004+ TfLiteRegistrationExternal * registration_external ;
1005+ } TfLiteRegistration_V2 ;
1006+
1007+ /// \private
1008+ // Old version of `TfLiteRegistration` to maintain binary backward
1009+ // compatibility.
1010+ // The legacy registration type must be a POD struct type whose field types must
1011+ // be a prefix of the field types in TfLiteRegistration, and offset of the first
1012+ // field in TfLiteRegistration that is not present in the legacy registration
1013+ // type must be greater than or equal to the size of the legacy registration
1014+ // type.
1015+ // WARNING: This structure is deprecated / not an official part of the
1016+ // API. It should be only used for binary backward compatibility.
9681017typedef struct TfLiteRegistration_V1 {
9691018 void * (* init )(TfLiteContext * context , const char * buffer , size_t length );
9701019 void (* free )(TfLiteContext * context , void * buffer );
@@ -1155,5 +1204,74 @@ void* TfLiteOpaqueDelegateGetData(const TfLiteOpaqueDelegate* delegate);
11551204
11561205#ifdef __cplusplus
11571206} // extern "C"
1207+
1208+ #include <utility>
1209+
1210+ // `kTfLiteVariant` type tensors encode arbitrary C++ objects behind their
1211+ // `data.data : void*` member. This is the type-erased interface for interacting
1212+ // with such objects at runtime. Deleting or Cloning any `VariantData`
1213+ // will call the destructor and copy constructor of the erased type
1214+ // automatically. For example usage, see `common_test.cc`.
1215+ class VariantData {
1216+ public :
1217+ // All variant objects must be able to be destroyed and copied.
1218+ virtual ~VariantData () = default ;
1219+ // This allows for a "virtual copy-constructor" like pattern.
1220+ // In most cases, we will be copying from an input to an output tensor.
1221+ // Often, the output tensor is already allocated so we can pass
1222+ // a pointer to its buffer for reuse.
1223+ virtual VariantData * Clone (char * maybe_alloc ) const = 0 ;
1224+ };
1225+
1226+ // An abstract base class for variant objects. The template parameter
1227+ // is the type we are erasing.
1228+ template < typename ErasedDerived >
1229+ class AbstractVariantData : public VariantData {
1230+ public :
1231+ VariantData * Clone (char * maybe_alloc ) const override {
1232+ if (maybe_alloc ) {
1233+ // We assume that the output tensor is already a variant of the same
1234+ // derived type. If the output is still allocated, then it still may have
1235+ // state that was not destroyed, so we must call the destructor before
1236+ // using the buffer.
1237+ // This may actual have a non-negligle effect on perfomance if the
1238+ // destructor is complex. In a future optimization we would want to
1239+ // introduce something like "move to" semantics, allowing for the
1240+ // underlying implementation to optimize for this case.
1241+ reinterpret_cast < VariantData * > (maybe_alloc )-> ~VariantData ();
1242+ return new (maybe_alloc )
1243+ ErasedDerived (static_cast < ErasedDerived const & > (* this ));
1244+ }
1245+ return new ErasedDerived (static_cast < ErasedDerived const & > (* this ));
1246+ }
1247+
1248+ protected :
1249+ AbstractVariantData () = default ;
1250+ AbstractVariantData (const AbstractVariantData & ) = default ;
1251+ AbstractVariantData (AbstractVariantData && ) = delete ;
1252+ };
1253+
1254+ // Analogous to `TfLiteTensorRealloc` for allocation of tensors whose
1255+ // data member points to an arbitrary C++ object. `VariantType` refers
1256+ // to the erased type of said object and `VariantArgs` refers to
1257+ // a list of argument types with which to construct a new `VariantType`
1258+ // `VariantArgs` must match constructor in `VariantType`.
1259+ template < class VariantType , class ... VariantArgs >
1260+ TfLiteStatus TfLiteTensorVariantRealloc (TfLiteTensor * t ,
1261+ VariantArgs && ... args ) {
1262+ if (t -> type != kTfLiteVariant ) return kTfLiteError ;
1263+ if (t -> data .raw ) {
1264+ reinterpret_cast < VariantData * > (t -> data .data )-> ~VariantData ();
1265+ // For now we assume if `t` is already allocated then it was allocated
1266+ // with the same `VariantType` as templated.
1267+ t -> data .data =
1268+ new (t -> data .raw ) VariantType (std ::forward < VariantArgs ...> (args ...));
1269+ } else {
1270+ t -> data .data = new VariantType (std ::forward < VariantArgs ...> (args ...));
1271+ }
1272+ t -> allocation_type = kTfLiteVariantObject ;
1273+ return kTfLiteOk ;
1274+ }
1275+
11581276#endif // __cplusplus
11591277#endif // TENSORFLOW_LITE_CORE_C_COMMON_H_
0 commit comments