@@ -341,6 +341,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
341341 return new Tensor_float64 (data , shape );
342342 }
343343
344+ /**
345+ *
346+ * Creates a new Tensor instance with given data-type and all elements initialized to one.
347+ *
348+ * @param shape Tensor shape
349+ * @param dtype Tensor data-type
350+ */
351+ public static Tensor ones (long [] shape , DType dtype ) {
352+ checkArgument (shape != null , ERROR_MSG_SHAPE_NOT_NULL );
353+ checkShape (shape );
354+ int numElements = (int ) numel (shape );
355+ switch (dtype ) {
356+ case UINT8 :
357+ byte [] uInt8Data = new byte [numElements ];
358+ Arrays .fill (uInt8Data , (byte ) 1 );
359+ return Tensor .fromBlobUnsigned (uInt8Data , shape );
360+ case INT8 :
361+ byte [] int8Data = new byte [numElements ];
362+ Arrays .fill (int8Data , (byte ) 1 );
363+ return Tensor .fromBlob (int8Data , shape );
364+ case INT32 :
365+ int [] int32Data = new int [numElements ];
366+ Arrays .fill (int32Data , 1 );
367+ return Tensor .fromBlob (int32Data , shape );
368+ case FLOAT :
369+ float [] float32Data = new float [numElements ];
370+ Arrays .fill (float32Data , 1.0f );
371+ return Tensor .fromBlob (float32Data , shape );
372+ case INT64 :
373+ long [] int64Data = new long [numElements ];
374+ Arrays .fill (int64Data , 1L );
375+ return Tensor .fromBlob (int64Data , shape );
376+ case DOUBLE :
377+ double [] float64Data = new double [numElements ];
378+ Arrays .fill (float64Data , 1.0 );
379+ return Tensor .fromBlob (float64Data , shape );
380+ default :
381+ throw new IllegalArgumentException (String .format ("Tensor.ones() cannot be used with DType %s" , dtype ));
382+ }
383+ }
384+
385+ /**
386+ *
387+ * Creates a new Tensor instance with given data-type and all elements initialized to zero.
388+ *
389+ * @param shape Tensor shape
390+ * @param dtype Tensor data-type
391+ */
392+ public static Tensor zeros (long [] shape , DType dtype ) {
393+ checkArgument (shape != null , ERROR_MSG_SHAPE_NOT_NULL );
394+ checkShape (shape );
395+ int numElements = (int ) numel (shape );
396+ switch (dtype ) {
397+ case UINT8 :
398+ byte [] uInt8Data = new byte [numElements ];
399+ return Tensor .fromBlobUnsigned (uInt8Data , shape );
400+ case INT8 :
401+ byte [] int8Data = new byte [numElements ];
402+ return Tensor .fromBlob (int8Data , shape );
403+ case INT32 :
404+ int [] int32Data = new int [numElements ];
405+ return Tensor .fromBlob (int32Data , shape );
406+ case FLOAT :
407+ float [] float32Data = new float [numElements ];
408+ return Tensor .fromBlob (float32Data , shape );
409+ case INT64 :
410+ long [] int64Data = new long [numElements ];
411+ return Tensor .fromBlob (int64Data , shape );
412+ case DOUBLE :
413+ double [] float64Data = new double [numElements ];
414+ return Tensor .fromBlob (float64Data , shape );
415+ default :
416+ throw new IllegalArgumentException (String .format ("Tensor.zeros() cannot be used with DType %s" , dtype ));
417+ }
418+ }
419+
344420 @ DoNotStrip private HybridData mHybridData ;
345421
346422 private Tensor (long [] shape ) {
0 commit comments