@@ -392,6 +392,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
392392 return new Tensor_float64 (data , shape );
393393 }
394394
395+ /**
396+ * Creates a new Tensor instance with given data-type and all elements initialized to one.
397+ *
398+ * @param shape Tensor shape
399+ * @param dtype Tensor data-type
400+ */
401+ public static Tensor ones (long [] shape , DType dtype ) {
402+ checkArgument (shape != null , ERROR_MSG_SHAPE_NOT_NULL );
403+ checkShape (shape );
404+ int numElements = (int ) numel (shape );
405+ switch (dtype ) {
406+ case UINT8 :
407+ byte [] uInt8Data = new byte [numElements ];
408+ Arrays .fill (uInt8Data , (byte ) 1 );
409+ return Tensor .fromBlobUnsigned (uInt8Data , shape );
410+ case INT8 :
411+ byte [] int8Data = new byte [numElements ];
412+ Arrays .fill (int8Data , (byte ) 1 );
413+ return Tensor .fromBlob (int8Data , shape );
414+ case INT32 :
415+ int [] int32Data = new int [numElements ];
416+ Arrays .fill (int32Data , 1 );
417+ return Tensor .fromBlob (int32Data , shape );
418+ case FLOAT :
419+ float [] float32Data = new float [numElements ];
420+ Arrays .fill (float32Data , 1.0f );
421+ return Tensor .fromBlob (float32Data , shape );
422+ case INT64 :
423+ long [] int64Data = new long [numElements ];
424+ Arrays .fill (int64Data , 1L );
425+ return Tensor .fromBlob (int64Data , shape );
426+ case DOUBLE :
427+ double [] float64Data = new double [numElements ];
428+ Arrays .fill (float64Data , 1.0 );
429+ return Tensor .fromBlob (float64Data , shape );
430+ default :
431+ throw new IllegalArgumentException (
432+ String .format ("Tensor.ones() cannot be used with DType %s" , dtype ));
433+ }
434+ }
435+
436+ /**
437+ * Creates a new Tensor instance with given data-type and all elements initialized to zero.
438+ *
439+ * @param shape Tensor shape
440+ * @param dtype Tensor data-type
441+ */
442+ public static Tensor zeros (long [] shape , DType dtype ) {
443+ checkArgument (shape != null , ERROR_MSG_SHAPE_NOT_NULL );
444+ checkShape (shape );
445+ int numElements = (int ) numel (shape );
446+ switch (dtype ) {
447+ case UINT8 :
448+ byte [] uInt8Data = new byte [numElements ];
449+ return Tensor .fromBlobUnsigned (uInt8Data , shape );
450+ case INT8 :
451+ byte [] int8Data = new byte [numElements ];
452+ return Tensor .fromBlob (int8Data , shape );
453+ case INT32 :
454+ int [] int32Data = new int [numElements ];
455+ return Tensor .fromBlob (int32Data , shape );
456+ case FLOAT :
457+ float [] float32Data = new float [numElements ];
458+ return Tensor .fromBlob (float32Data , shape );
459+ case INT64 :
460+ long [] int64Data = new long [numElements ];
461+ return Tensor .fromBlob (int64Data , shape );
462+ case DOUBLE :
463+ double [] float64Data = new double [numElements ];
464+ return Tensor .fromBlob (float64Data , shape );
465+ default :
466+ throw new IllegalArgumentException (
467+ String .format ("Tensor.zeros() cannot be used with DType %s" , dtype ));
468+ }
469+ }
470+
395471 @ DoNotStrip private HybridData mHybridData ;
396472
397473 private Tensor (long [] shape ) {
0 commit comments