Skip to content

Commit 4c6d19d

Browse files
committed
Android: Add Tensor.ones() and Tensor.zeros() factory method to create tensors initialized with ones and zeros resp. (#15125)
1 parent 7ce78c0 commit 4c6d19d

File tree

2 files changed

+104
-0
lines changed
  • extension/android/executorch_android/src

2 files changed

+104
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
341341
return new Tensor_float64(data, shape);
342342
}
343343

344+
/**
345+
*
346+
* Creates a new Tensor instance with given data-type and all elements initialized to one.
347+
*
348+
* @param shape Tensor shape
349+
* @param dtype Tensor data-type
350+
*/
351+
public static Tensor ones(long[] shape, DType dtype) {
352+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
353+
checkShape(shape);
354+
int numElements = (int) numel(shape);
355+
switch (dtype) {
356+
case UINT8:
357+
byte[] uInt8Data = new byte[numElements];
358+
Arrays.fill(uInt8Data, (byte) 1);
359+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
360+
case INT8:
361+
byte[] int8Data = new byte[numElements];
362+
Arrays.fill(int8Data, (byte) 1);
363+
return Tensor.fromBlob(int8Data, shape);
364+
case INT32:
365+
int[] int32Data = new int[numElements];
366+
Arrays.fill(int32Data, 1);
367+
return Tensor.fromBlob(int32Data, shape);
368+
case FLOAT:
369+
float[] float32Data = new float[numElements];
370+
Arrays.fill(float32Data, 1.0f);
371+
return Tensor.fromBlob(float32Data, shape);
372+
case INT64:
373+
long[] int64Data = new long[numElements];
374+
Arrays.fill(int64Data, 1L);
375+
return Tensor.fromBlob(int64Data, shape);
376+
case DOUBLE:
377+
double[] float64Data = new double[numElements];
378+
Arrays.fill(float64Data, 1.0);
379+
return Tensor.fromBlob(float64Data, shape);
380+
default:
381+
throw new IllegalArgumentException(String.format("Tensor.ones() cannot be used with DType %s", dtype));
382+
}
383+
}
384+
385+
/**
386+
*
387+
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
388+
*
389+
* @param shape Tensor shape
390+
* @param dtype Tensor data-type
391+
*/
392+
public static Tensor zeros(long[] shape, DType dtype) {
393+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
394+
checkShape(shape);
395+
int numElements = (int) numel(shape);
396+
switch (dtype) {
397+
case UINT8:
398+
byte[] uInt8Data = new byte[numElements];
399+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
400+
case INT8:
401+
byte[] int8Data = new byte[numElements];
402+
return Tensor.fromBlob(int8Data, shape);
403+
case INT32:
404+
int[] int32Data = new int[numElements];
405+
return Tensor.fromBlob(int32Data, shape);
406+
case FLOAT:
407+
float[] float32Data = new float[numElements];
408+
return Tensor.fromBlob(float32Data, shape);
409+
case INT64:
410+
long[] int64Data = new long[numElements];
411+
return Tensor.fromBlob(int64Data, shape);
412+
case DOUBLE:
413+
double[] float64Data = new double[numElements];
414+
return Tensor.fromBlob(float64Data, shape);
415+
default:
416+
throw new IllegalArgumentException(String.format("Tensor.zeros() cannot be used with DType %s", dtype));
417+
}
418+
}
419+
344420
@DoNotStrip private HybridData mHybridData;
345421

346422
private Tensor(long[] shape) {

extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,32 @@ class TensorTest {
274274
assertEquals(shape[i], deserShape[i])
275275
}
276276
}
277+
278+
@Test
279+
fun testOnes_DTypeIsFloat() {
280+
val shape = longArrayOf(2, 2)
281+
val tensor = Tensor.ones(shape, DType.FLOAT)
282+
val data = tensor.dataAsFloatArray
283+
assertEquals(DType.FLOAT, tensor.dtype())
284+
for (i in shape.indices) {
285+
assertEquals(shape[i], tensor.shape[i])
286+
}
287+
for (i in data.indices) {
288+
assertEquals(data[i], 1.0f, 1e-5.toFloat())
289+
}
290+
}
291+
292+
@Test
293+
fun testZeros_DTypeIsFloat() {
294+
val shape = longArrayOf(2, 2)
295+
val tensor = Tensor.zeros(shape, DType.FLOAT)
296+
val data = tensor.dataAsFloatArray
297+
assertEquals(DType.FLOAT, tensor.dtype())
298+
for (i in shape.indices) {
299+
assertEquals(shape[i], tensor.shape[i])
300+
}
301+
for (i in data.indices) {
302+
assertEquals(data[i], 0.0f, 1e-5.toFloat())
303+
}
304+
}
277305
}

0 commit comments

Comments
 (0)