@@ -192,24 +192,24 @@ class TensorTest {
192192 assertEquals(tensor.dtype(), DType .FLOAT )
193193
194194 assertThatThrownBy { tensor.dataAsByteArray }
195- .isInstanceOf(IllegalStateException ::class .java)
196- .hasMessage(" Tensor of type Tensor_float32 cannot return data as byte array." )
195+ .isInstanceOf(IllegalStateException ::class .java)
196+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as byte array." )
197197
198198 assertThatThrownBy { tensor.dataAsUnsignedByteArray }
199- .isInstanceOf(IllegalStateException ::class .java)
200- .hasMessage(" Tensor of type Tensor_float32 cannot return data as unsigned byte array." )
199+ .isInstanceOf(IllegalStateException ::class .java)
200+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as unsigned byte array." )
201201
202202 assertThatThrownBy { tensor.dataAsIntArray }
203- .isInstanceOf(IllegalStateException ::class .java)
204- .hasMessage(" Tensor of type Tensor_float32 cannot return data as int array." )
203+ .isInstanceOf(IllegalStateException ::class .java)
204+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as int array." )
205205
206206 assertThatThrownBy { tensor.dataAsDoubleArray }
207- .isInstanceOf(IllegalStateException ::class .java)
208- .hasMessage(" Tensor of type Tensor_float32 cannot return data as double array." )
207+ .isInstanceOf(IllegalStateException ::class .java)
208+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as double array." )
209209
210210 assertThatThrownBy { tensor.dataAsLongArray }
211- .isInstanceOf(IllegalStateException ::class .java)
212- .hasMessage(" Tensor of type Tensor_float32 cannot return data as long array." )
211+ .isInstanceOf(IllegalStateException ::class .java)
212+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as long array." )
213213 }
214214
215215 @Test
@@ -219,20 +219,20 @@ class TensorTest {
219219 val mismatchShape = longArrayOf(1 , 2 )
220220
221221 assertThatThrownBy { Tensor .fromBlob(null as FloatArray? , mismatchShape) }
222- .isInstanceOf(IllegalArgumentException ::class .java)
223- .hasMessage(" Data array must be not null" )
222+ .isInstanceOf(IllegalArgumentException ::class .java)
223+ .hasMessage(" Data array must be not null" )
224224
225225 assertThatThrownBy { Tensor .fromBlob(data, null ) }
226- .isInstanceOf(IllegalArgumentException ::class .java)
227- .hasMessage(" Shape must be not null" )
226+ .isInstanceOf(IllegalArgumentException ::class .java)
227+ .hasMessage(" Shape must be not null" )
228228
229229 assertThatThrownBy { Tensor .fromBlob(data, shapeWithNegativeValues) }
230- .isInstanceOf(IllegalArgumentException ::class .java)
231- .hasMessage(" Shape elements must be non negative" )
230+ .isInstanceOf(IllegalArgumentException ::class .java)
231+ .hasMessage(" Shape elements must be non negative" )
232232
233233 assertThatThrownBy { Tensor .fromBlob(data, mismatchShape) }
234- .isInstanceOf(IllegalArgumentException ::class .java)
235- .hasMessage(" Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]" )
234+ .isInstanceOf(IllegalArgumentException ::class .java)
235+ .hasMessage(" Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]" )
236236 }
237237
238238 @Test
@@ -274,4 +274,32 @@ class TensorTest {
274274 assertEquals(shape[i], deserShape[i])
275275 }
276276 }
277+
278+ @Test
279+ fun testOnes_DTypeIsFloat () {
280+ val shape = longArrayOf(2 , 2 )
281+ val tensor = Tensor .ones(shape, DType .FLOAT )
282+ val data = tensor.dataAsFloatArray
283+ assertEquals(DType .FLOAT , tensor.dtype())
284+ for (i in shape.indices) {
285+ assertEquals(shape[i], tensor.shape[i])
286+ }
287+ for (i in data.indices) {
288+ assertEquals(data[i], 1.0f , 1e- 5 .toFloat())
289+ }
290+ }
291+
292+ @Test
293+ fun testZeros_DTypeIsFloat () {
294+ val shape = longArrayOf(2 , 2 )
295+ val tensor = Tensor .zeros(shape, DType .FLOAT )
296+ val data = tensor.dataAsFloatArray
297+ assertEquals(DType .FLOAT , tensor.dtype())
298+ for (i in shape.indices) {
299+ assertEquals(shape[i], tensor.shape[i])
300+ }
301+ for (i in data.indices) {
302+ assertEquals(data[i], 0.0f , 1e- 5 .toFloat())
303+ }
304+ }
277305}
0 commit comments