Skip to content

Commit f92565f

Browse files
authored
Android: Add Tensor.ones() and Tensor.zeros() factory methods (#15388)
Fixes #15125 Release notes: `misc` ### Summary - [x] Add implementation and test for a new factory method `Tensor.ones()` in the `android` extension ### Test plan The following new JUnit tests were added to verify the behavior of `Tensor.ones()` in `extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt` - `testOnes_DTypeIsFloat` cc @kirklandsign @cbilgin
1 parent be741ad commit f92565f

File tree

2 files changed

+104
-0
lines changed
  • extension/android/executorch_android/src

2 files changed

+104
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
392392
return new Tensor_float64(data, shape);
393393
}
394394

395+
/**
396+
* Creates a new Tensor instance with given data-type and all elements initialized to one.
397+
*
398+
* @param shape Tensor shape
399+
* @param dtype Tensor data-type
400+
*/
401+
public static Tensor ones(long[] shape, DType dtype) {
402+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
403+
checkShape(shape);
404+
int numElements = (int) numel(shape);
405+
switch (dtype) {
406+
case UINT8:
407+
byte[] uInt8Data = new byte[numElements];
408+
Arrays.fill(uInt8Data, (byte) 1);
409+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
410+
case INT8:
411+
byte[] int8Data = new byte[numElements];
412+
Arrays.fill(int8Data, (byte) 1);
413+
return Tensor.fromBlob(int8Data, shape);
414+
case INT32:
415+
int[] int32Data = new int[numElements];
416+
Arrays.fill(int32Data, 1);
417+
return Tensor.fromBlob(int32Data, shape);
418+
case FLOAT:
419+
float[] float32Data = new float[numElements];
420+
Arrays.fill(float32Data, 1.0f);
421+
return Tensor.fromBlob(float32Data, shape);
422+
case INT64:
423+
long[] int64Data = new long[numElements];
424+
Arrays.fill(int64Data, 1L);
425+
return Tensor.fromBlob(int64Data, shape);
426+
case DOUBLE:
427+
double[] float64Data = new double[numElements];
428+
Arrays.fill(float64Data, 1.0);
429+
return Tensor.fromBlob(float64Data, shape);
430+
default:
431+
throw new IllegalArgumentException(
432+
String.format("Tensor.ones() cannot be used with DType %s", dtype));
433+
}
434+
}
435+
436+
/**
437+
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
438+
*
439+
* @param shape Tensor shape
440+
* @param dtype Tensor data-type
441+
*/
442+
public static Tensor zeros(long[] shape, DType dtype) {
443+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
444+
checkShape(shape);
445+
int numElements = (int) numel(shape);
446+
switch (dtype) {
447+
case UINT8:
448+
byte[] uInt8Data = new byte[numElements];
449+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
450+
case INT8:
451+
byte[] int8Data = new byte[numElements];
452+
return Tensor.fromBlob(int8Data, shape);
453+
case INT32:
454+
int[] int32Data = new int[numElements];
455+
return Tensor.fromBlob(int32Data, shape);
456+
case FLOAT:
457+
float[] float32Data = new float[numElements];
458+
return Tensor.fromBlob(float32Data, shape);
459+
case INT64:
460+
long[] int64Data = new long[numElements];
461+
return Tensor.fromBlob(int64Data, shape);
462+
case DOUBLE:
463+
double[] float64Data = new double[numElements];
464+
return Tensor.fromBlob(float64Data, shape);
465+
default:
466+
throw new IllegalArgumentException(
467+
String.format("Tensor.zeros() cannot be used with DType %s", dtype));
468+
}
469+
}
470+
395471
@DoNotStrip private HybridData mHybridData;
396472

397473
private Tensor(long[] shape) {

extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,32 @@ class TensorTest {
336336
assertEquals(shape[i], deserShape[i])
337337
}
338338
}
339+
340+
@Test
341+
fun testOnes_DTypeIsFloat() {
342+
val shape = longArrayOf(2, 2)
343+
val tensor = Tensor.ones(shape, DType.FLOAT)
344+
val data = tensor.dataAsFloatArray
345+
assertEquals(DType.FLOAT, tensor.dtype())
346+
for (i in shape.indices) {
347+
assertEquals(shape[i], tensor.shape[i])
348+
}
349+
for (i in data.indices) {
350+
assertEquals(data[i], 1.0f, 1e-5.toFloat())
351+
}
352+
}
353+
354+
@Test
355+
fun testZeros_DTypeIsFloat() {
356+
val shape = longArrayOf(2, 2)
357+
val tensor = Tensor.zeros(shape, DType.FLOAT)
358+
val data = tensor.dataAsFloatArray
359+
assertEquals(DType.FLOAT, tensor.dtype())
360+
for (i in shape.indices) {
361+
assertEquals(shape[i], tensor.shape[i])
362+
}
363+
for (i in data.indices) {
364+
assertEquals(data[i], 0.0f, 1e-5.toFloat())
365+
}
366+
}
339367
}

0 commit comments

Comments
 (0)