@@ -41,7 +41,7 @@ impl Tensor {
41
41
}
42
42
43
43
/// (Re)Set the shape of the tensor to a new shape.
44
- pub fn set_shape ( & self , shape : Shape ) -> Result < Self > {
44
+ pub fn set_shape ( & self , shape : & Shape ) -> Result < Self > {
45
45
try_unsafe ! ( ov_tensor_set_shape( self . ptr, shape. as_c_struct( ) ) ) ?;
46
46
Ok ( Self { ptr : self . ptr } )
47
47
}
@@ -57,6 +57,10 @@ impl Tensor {
57
57
}
58
58
59
59
/// Get the data type of elements of the tensor.
60
+ ///
61
+ /// # Panics
62
+ ///
63
+ /// This function panics in the unlikely case OpenVINO returns an unknown element type.
60
64
pub fn get_element_type ( & self ) -> Result < ElementType > {
61
65
let mut element_type = ElementType :: Undefined as u32 ;
62
66
try_unsafe ! ( ov_tensor_get_element_type(
@@ -135,12 +139,14 @@ impl Tensor {
135
139
/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
136
140
/// length of that slice.
137
141
fn get_safe_len < T > ( data : & [ u8 ] ) -> usize {
138
- if data. len ( ) % std:: mem:: size_of :: < T > ( ) != 0 {
139
- panic ! ( "data size is not a multiple of the size of `T`" ) ;
140
- }
141
- if data. as_ptr ( ) as usize % std:: mem:: align_of :: < T > ( ) != 0 {
142
- panic ! ( "raw data is not aligned to `T`'s alignment" ) ;
143
- }
142
+ assert ! (
143
+ data. len( ) % std:: mem:: size_of:: <T >( ) == 0 ,
144
+ "data size is not a multiple of the size of `T`"
145
+ ) ;
146
+ assert ! (
147
+ data. as_ptr( ) as usize % std:: mem:: align_of:: <T >( ) == 0 ,
148
+ "raw data is not aligned to `T`'s alignment"
149
+ ) ;
144
150
data. len ( ) / std:: mem:: size_of :: < T > ( )
145
151
}
146
152
@@ -151,66 +157,51 @@ mod tests {
151
157
#[ test]
152
158
fn test_create_tensor ( ) {
153
159
openvino_sys:: library:: load ( ) . unwrap ( ) ;
154
- let shape = Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ;
160
+ let shape = Shape :: new ( & [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ;
155
161
let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
156
162
assert ! ( !tensor. ptr. is_null( ) ) ;
157
163
}
158
164
159
165
#[ test]
160
166
fn test_get_shape ( ) {
161
167
openvino_sys:: library:: load ( ) . unwrap ( ) ;
162
- let tensor = Tensor :: new (
163
- ElementType :: F32 ,
164
- & Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
165
- )
166
- . unwrap ( ) ;
168
+ let tensor =
169
+ Tensor :: new ( ElementType :: F32 , & Shape :: new ( & [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ) . unwrap ( ) ;
167
170
let shape = tensor. get_shape ( ) . unwrap ( ) ;
168
171
assert_eq ! ( shape. get_rank( ) , 4 ) ;
169
172
}
170
173
171
174
#[ test]
172
175
fn test_get_element_type ( ) {
173
176
openvino_sys:: library:: load ( ) . unwrap ( ) ;
174
- let tensor = Tensor :: new (
175
- ElementType :: F32 ,
176
- & Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
177
- )
178
- . unwrap ( ) ;
177
+ let tensor =
178
+ Tensor :: new ( ElementType :: F32 , & Shape :: new ( & [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ) . unwrap ( ) ;
179
179
let element_type = tensor. get_element_type ( ) . unwrap ( ) ;
180
180
assert_eq ! ( element_type, ElementType :: F32 ) ;
181
181
}
182
182
183
183
#[ test]
184
184
fn test_get_size ( ) {
185
185
openvino_sys:: library:: load ( ) . unwrap ( ) ;
186
- let tensor = Tensor :: new (
187
- ElementType :: F32 ,
188
- & Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
189
- )
190
- . unwrap ( ) ;
186
+ let tensor =
187
+ Tensor :: new ( ElementType :: F32 , & Shape :: new ( & [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ) . unwrap ( ) ;
191
188
let size = tensor. get_size ( ) . unwrap ( ) ;
192
- assert_eq ! ( size, 1 * 3 * 227 * 227 ) ;
189
+ assert_eq ! ( size, 3 * 227 * 227 ) ;
193
190
}
194
191
195
192
#[ test]
196
193
fn test_get_byte_size ( ) {
197
194
openvino_sys:: library:: load ( ) . unwrap ( ) ;
198
- let tensor = Tensor :: new (
199
- ElementType :: F32 ,
200
- & Shape :: new ( & vec ! [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ,
201
- )
202
- . unwrap ( ) ;
195
+ let tensor =
196
+ Tensor :: new ( ElementType :: F32 , & Shape :: new ( & [ 1 , 3 , 227 , 227 ] ) . unwrap ( ) ) . unwrap ( ) ;
203
197
let byte_size = tensor. get_byte_size ( ) . unwrap ( ) ;
204
- assert_eq ! (
205
- byte_size,
206
- 1 * 3 * 227 * 227 * std:: mem:: size_of:: <f32 >( ) as usize
207
- ) ;
198
+ assert_eq ! ( byte_size, 3 * 227 * 227 * std:: mem:: size_of:: <f32 >( ) ) ;
208
199
}
209
200
210
201
#[ test]
211
202
fn casting ( ) {
212
203
openvino_sys:: library:: load ( ) . unwrap ( ) ;
213
- let shape = Shape :: new ( & vec ! [ 10 , 10 , 10 ] ) . unwrap ( ) ;
204
+ let shape = Shape :: new ( & [ 10 , 10 , 10 ] ) . unwrap ( ) ;
214
205
let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
215
206
let data = tensor. get_data :: < f32 > ( ) . unwrap ( ) ;
216
207
assert_eq ! ( data. len( ) , 10 * 10 * 10 ) ;
@@ -220,7 +211,7 @@ mod tests {
220
211
#[ should_panic( expected = "data size is not a multiple of the size of `T`" ) ]
221
212
fn casting_check ( ) {
222
213
openvino_sys:: library:: load ( ) . unwrap ( ) ;
223
- let shape = Shape :: new ( & vec ! [ 10 , 10 , 10 ] ) . unwrap ( ) ;
214
+ let shape = Shape :: new ( & [ 10 , 10 , 10 ] ) . unwrap ( ) ;
224
215
let tensor = Tensor :: new ( ElementType :: F32 , & shape) . unwrap ( ) ;
225
216
#[ allow( dead_code) ]
226
217
struct LargeOddType ( [ u8 ; 1061 ] ) ;
0 commit comments