Skip to content

Commit def5753

Browse files
committed
Android: Add Tensor.ones() and Tensor.zeros() factory method to create tensors initialized with ones and zeros resp. (#15125)
1 parent 7ce78c0 commit def5753

File tree

2 files changed

+122
-18
lines changed
  • extension/android/executorch_android/src

2 files changed

+122
-18
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
341341
return new Tensor_float64(data, shape);
342342
}
343343

344+
/**
345+
* Creates a new Tensor instance with given data-type and all elements initialized to one.
346+
*
347+
* @param shape Tensor shape
348+
* @param dtype Tensor data-type
349+
*/
350+
public static Tensor ones(long[] shape, DType dtype) {
351+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
352+
checkShape(shape);
353+
int numElements = (int) numel(shape);
354+
switch (dtype) {
355+
case UINT8:
356+
byte[] uInt8Data = new byte[numElements];
357+
Arrays.fill(uInt8Data, (byte) 1);
358+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
359+
case INT8:
360+
byte[] int8Data = new byte[numElements];
361+
Arrays.fill(int8Data, (byte) 1);
362+
return Tensor.fromBlob(int8Data, shape);
363+
case INT32:
364+
int[] int32Data = new int[numElements];
365+
Arrays.fill(int32Data, 1);
366+
return Tensor.fromBlob(int32Data, shape);
367+
case FLOAT:
368+
float[] float32Data = new float[numElements];
369+
Arrays.fill(float32Data, 1.0f);
370+
return Tensor.fromBlob(float32Data, shape);
371+
case INT64:
372+
long[] int64Data = new long[numElements];
373+
Arrays.fill(int64Data, 1L);
374+
return Tensor.fromBlob(int64Data, shape);
375+
case DOUBLE:
376+
double[] float64Data = new double[numElements];
377+
Arrays.fill(float64Data, 1.0);
378+
return Tensor.fromBlob(float64Data, shape);
379+
default:
380+
throw new IllegalArgumentException(
381+
String.format("Tensor.ones() cannot be used with DType %s", dtype));
382+
}
383+
}
384+
385+
/**
386+
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
387+
*
388+
* @param shape Tensor shape
389+
* @param dtype Tensor data-type
390+
*/
391+
public static Tensor zeros(long[] shape, DType dtype) {
392+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
393+
checkShape(shape);
394+
int numElements = (int) numel(shape);
395+
switch (dtype) {
396+
case UINT8:
397+
byte[] uInt8Data = new byte[numElements];
398+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
399+
case INT8:
400+
byte[] int8Data = new byte[numElements];
401+
return Tensor.fromBlob(int8Data, shape);
402+
case INT32:
403+
int[] int32Data = new int[numElements];
404+
return Tensor.fromBlob(int32Data, shape);
405+
case FLOAT:
406+
float[] float32Data = new float[numElements];
407+
return Tensor.fromBlob(float32Data, shape);
408+
case INT64:
409+
long[] int64Data = new long[numElements];
410+
return Tensor.fromBlob(int64Data, shape);
411+
case DOUBLE:
412+
double[] float64Data = new double[numElements];
413+
return Tensor.fromBlob(float64Data, shape);
414+
default:
415+
throw new IllegalArgumentException(
416+
String.format("Tensor.zeros() cannot be used with DType %s", dtype));
417+
}
418+
}
419+
344420
@DoNotStrip private HybridData mHybridData;
345421

346422
private Tensor(long[] shape) {

extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -192,24 +192,24 @@ class TensorTest {
192192
assertEquals(tensor.dtype(), DType.FLOAT)
193193

194194
assertThatThrownBy { tensor.dataAsByteArray }
195-
.isInstanceOf(IllegalStateException::class.java)
196-
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")
195+
.isInstanceOf(IllegalStateException::class.java)
196+
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")
197197

198198
assertThatThrownBy { tensor.dataAsUnsignedByteArray }
199-
.isInstanceOf(IllegalStateException::class.java)
200-
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")
199+
.isInstanceOf(IllegalStateException::class.java)
200+
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")
201201

202202
assertThatThrownBy { tensor.dataAsIntArray }
203-
.isInstanceOf(IllegalStateException::class.java)
204-
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")
203+
.isInstanceOf(IllegalStateException::class.java)
204+
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")
205205

206206
assertThatThrownBy { tensor.dataAsDoubleArray }
207-
.isInstanceOf(IllegalStateException::class.java)
208-
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")
207+
.isInstanceOf(IllegalStateException::class.java)
208+
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")
209209

210210
assertThatThrownBy { tensor.dataAsLongArray }
211-
.isInstanceOf(IllegalStateException::class.java)
212-
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
211+
.isInstanceOf(IllegalStateException::class.java)
212+
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
213213
}
214214

215215
@Test
@@ -219,20 +219,20 @@ class TensorTest {
219219
val mismatchShape = longArrayOf(1, 2)
220220

221221
assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) }
222-
.isInstanceOf(IllegalArgumentException::class.java)
223-
.hasMessage("Data array must be not null")
222+
.isInstanceOf(IllegalArgumentException::class.java)
223+
.hasMessage("Data array must be not null")
224224

225225
assertThatThrownBy { Tensor.fromBlob(data, null) }
226-
.isInstanceOf(IllegalArgumentException::class.java)
227-
.hasMessage("Shape must be not null")
226+
.isInstanceOf(IllegalArgumentException::class.java)
227+
.hasMessage("Shape must be not null")
228228

229229
assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) }
230-
.isInstanceOf(IllegalArgumentException::class.java)
231-
.hasMessage("Shape elements must be non negative")
230+
.isInstanceOf(IllegalArgumentException::class.java)
231+
.hasMessage("Shape elements must be non negative")
232232

233233
assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) }
234-
.isInstanceOf(IllegalArgumentException::class.java)
235-
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
234+
.isInstanceOf(IllegalArgumentException::class.java)
235+
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
236236
}
237237

238238
@Test
@@ -274,4 +274,32 @@ class TensorTest {
274274
assertEquals(shape[i], deserShape[i])
275275
}
276276
}
277+
278+
@Test
279+
fun testOnes_DTypeIsFloat() {
280+
val shape = longArrayOf(2, 2)
281+
val tensor = Tensor.ones(shape, DType.FLOAT)
282+
val data = tensor.dataAsFloatArray
283+
assertEquals(DType.FLOAT, tensor.dtype())
284+
for (i in shape.indices) {
285+
assertEquals(shape[i], tensor.shape[i])
286+
}
287+
for (i in data.indices) {
288+
assertEquals(data[i], 1.0f, 1e-5.toFloat())
289+
}
290+
}
291+
292+
@Test
293+
fun testZeros_DTypeIsFloat() {
294+
val shape = longArrayOf(2, 2)
295+
val tensor = Tensor.zeros(shape, DType.FLOAT)
296+
val data = tensor.dataAsFloatArray
297+
assertEquals(DType.FLOAT, tensor.dtype())
298+
for (i in shape.indices) {
299+
assertEquals(shape[i], tensor.shape[i])
300+
}
301+
for (i in data.indices) {
302+
assertEquals(data[i], 0.0f, 1e-5.toFloat())
303+
}
304+
}
277305
}

0 commit comments

Comments
 (0)