@@ -343,8 +343,8 @@ pub enum TensorElementDataType {
343343 Int32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt ,
344344 /// Signed 64-bit int, equivalent to Rust's `i64`
345345 Int64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt ,
346- // // / String, equivalent to Rust's `String`
347- // String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
346+ /// String, equivalent to Rust's `String`
347+ String = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt ,
348348 // /// Boolean, equivalent to Rust's `bool`
349349 // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
350350 // /// 16-bit floating point, equivalent to Rust's `f16`
@@ -374,9 +374,7 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
374374 Int16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 ,
375375 Int32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ,
376376 Int64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 ,
377- // String => {
378- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
379- // }
377+ String => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING ,
380378 // Bool => {
381379 // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
382380 // }
@@ -402,15 +400,22 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
402400/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
403401pub trait TypeToTensorElementDataType {
404402 /// Return the ONNX type for a Rust type
405- fn tensor_element_data_type ( ) -> sys:: ONNXTensorElementDataType ;
403+ fn tensor_element_data_type ( ) -> TensorElementDataType ;
404+
405+ /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
406+ fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > ;
406407}
407408
408409macro_rules! impl_type_trait {
409410 ( $type_: ty, $variant: ident) => {
410411 impl TypeToTensorElementDataType for $type_ {
411- fn tensor_element_data_type( ) -> sys :: ONNXTensorElementDataType {
412+ fn tensor_element_data_type( ) -> TensorElementDataType {
412413 // unsafe { std::mem::transmute(TensorElementDataType::$variant) }
413- TensorElementDataType :: $variant. into( )
414+ TensorElementDataType :: $variant
415+ }
416+
417+ fn try_utf8_bytes( & self ) -> Option <& [ u8 ] > {
418+ None
414419 }
415420 }
416421 } ;
@@ -423,7 +428,6 @@ impl_type_trait!(u16, Uint16);
423428impl_type_trait ! ( i16 , Int16 ) ;
424429impl_type_trait ! ( i32 , Int32 ) ;
425430impl_type_trait ! ( i64 , Int64 ) ;
426- // impl_type_trait!(String, String);
427431// impl_type_trait!(bool, Bool);
428432// impl_type_trait!(f16, Float16);
429433impl_type_trait ! ( f64 , Double ) ;
@@ -433,6 +437,39 @@ impl_type_trait!(u64, Uint64);
433437// impl_type_trait!(, Complex128);
434438// impl_type_trait!(, Bfloat16);
435439
/// Adapter for common Rust string types to ONNX strings.
///
/// It should be easy to use both `String` and `&str` as [`TensorElementDataType::String`] data,
/// but we can't define an automatic implementation for anything that implements `AsRef<str>`,
/// as it would conflict with the implementations of [`TypeToTensorElementDataType`] for
/// primitive numeric types (which might implement `AsRef<str>` at some point in the future).
pub trait Utf8Data {
    /// Returns the UTF-8 contents as a byte slice.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for String {
    /// Returns the bytes of the owned string.
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl<'a> Utf8Data for &'a str {
    /// Returns the bytes of the borrowed string slice.
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}
462+
463+ impl < T : Utf8Data > TypeToTensorElementDataType for T {
464+ fn tensor_element_data_type ( ) -> TensorElementDataType {
465+ TensorElementDataType :: String
466+ }
467+
468+ fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > {
469+ Some ( self . utf8_bytes ( ) )
470+ }
471+ }
472+
436473/// Allocator type
437474#[ derive( Debug , Clone ) ]
438475#[ repr( i32 ) ]
0 commit comments