Skip to content

Commit 28959cf

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Helpers to create random integer tensor.
Summary: #8366 Reviewed By: kirklandsign Differential Revision: D74020940
1 parent 74dbf15 commit 28959cf

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,4 +980,111 @@ __attribute__((deprecated("This API is experimental.")))
980980

981981
@end
982982

983+
#pragma mark - RandomInteger Category
984+
985+
@interface ExecuTorchTensor (RandomInteger)
986+
987+
/**
988+
* Creates a tensor with random integer values in the specified range,
989+
* with full specification of shape, strides, data type, and shape dynamism.
990+
*
991+
* @param low An NSInteger specifying the inclusive lower bound of random values.
992+
* @param high An NSInteger specifying the exclusive upper bound of random values.
993+
* @param shape An NSArray of NSNumber objects representing the desired shape.
994+
* @param strides An NSArray of NSNumber objects representing the desired strides.
995+
* @param dataType An ExecuTorchDataType value specifying the element type.
996+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
997+
* @return A new ExecuTorchTensor instance filled with random integer values.
998+
*/
999+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1000+
high:(NSInteger)high
1001+
shape:(NSArray<NSNumber *> *)shape
1002+
strides:(NSArray<NSNumber *> *)strides
1003+
dataType:(ExecuTorchDataType)dataType
1004+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1005+
NS_SWIFT_NAME(randint(low:high:shape:strides:dataType:shapeDynamism:));
1006+
1007+
/**
1008+
* Creates a tensor with random integer values in the specified range,
1009+
* with the given shape and data type.
1010+
*
1011+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1012+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1013+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1014+
* @param dataType An ExecuTorchDataType value specifying the element type.
1015+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1016+
* @return A new ExecuTorchTensor instance filled with random integer values.
1017+
*/
1018+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1019+
high:(NSInteger)high
1020+
shape:(NSArray<NSNumber *> *)shape
1021+
dataType:(ExecuTorchDataType)dataType
1022+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1023+
NS_SWIFT_NAME(randint(low:high:shape:dataType:shapeDynamism:));
1024+
1025+
/**
1026+
* Creates a tensor with random integer values in the specified range,
1027+
* with the given shape (using dynamic bound shape) and data type.
1028+
*
1029+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1030+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1031+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1032+
* @param dataType An ExecuTorchDataType value specifying the element type.
1033+
* @return A new ExecuTorchTensor instance filled with random integer values.
1034+
*/
1035+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1036+
high:(NSInteger)high
1037+
shape:(NSArray<NSNumber *> *)shape
1038+
dataType:(ExecuTorchDataType)dataType
1039+
NS_SWIFT_NAME(randint(low:high:shape:dataType:));
1040+
1041+
/**
1042+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor,
1043+
* with the given data type and shape dynamism.
1044+
*
1045+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1046+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1047+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1048+
* @param dataType An ExecuTorchDataType value specifying the element type.
1049+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1050+
* @return A new ExecuTorchTensor instance filled with random integer values.
1051+
*/
1052+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1053+
low:(NSInteger)low
1054+
high:(NSInteger)high
1055+
dataType:(ExecuTorchDataType)dataType
1056+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1057+
NS_SWIFT_NAME(randint(like:low:high:dataType:shapeDynamism:));
1058+
1059+
/**
1060+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor,
1061+
* with the given data type.
1062+
*
1063+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1064+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1065+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1066+
* @param dataType An ExecuTorchDataType value specifying the element type.
1067+
* @return A new ExecuTorchTensor instance filled with random integer values.
1068+
*/
1069+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1070+
low:(NSInteger)low
1071+
high:(NSInteger)high
1072+
dataType:(ExecuTorchDataType)dataType
1073+
NS_SWIFT_NAME(randint(like:low:high:dataType:));
1074+
1075+
/**
1076+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor.
1077+
*
1078+
* @param tensor An existing ExecuTorchTensor instance.
1079+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1080+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1081+
* @return A new ExecuTorchTensor instance filled with random integer values.
1082+
*/
1083+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1084+
low:(NSInteger)low
1085+
high:(NSInteger)high
1086+
NS_SWIFT_NAME(randint(like:low:high:));
1087+
1088+
@end
1089+
9831090
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,85 @@ + (instancetype)randomTensorLikeTensor:(ExecuTorchTensor *)tensor {
833833
}
834834

835835
@end
836+
837+
@implementation ExecuTorchTensor (RandomInteger)
838+
839+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
840+
high:(NSInteger)high
841+
shape:(NSArray<NSNumber *> *)shape
842+
strides:(NSArray<NSNumber *> *)strides
843+
dataType:(ExecuTorchDataType)dataType
844+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
845+
auto tensor = randint_strided(
846+
low,
847+
high,
848+
utils::toVector<SizesType>(shape),
849+
utils::toVector<StridesType>(strides),
850+
static_cast<ScalarType>(dataType),
851+
static_cast<TensorShapeDynamism>(shapeDynamism)
852+
);
853+
return [[self alloc] initWithNativeInstance:&tensor];
854+
}
855+
856+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
857+
high:(NSInteger)high
858+
shape:(NSArray<NSNumber *> *)shape
859+
dataType:(ExecuTorchDataType)dataType
860+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
861+
return [self randomIntegerTensorWithLow:low
862+
high:high
863+
shape:shape
864+
strides:@[]
865+
dataType:dataType
866+
shapeDynamism:shapeDynamism];
867+
}
868+
869+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
870+
high:(NSInteger)high
871+
shape:(NSArray<NSNumber *> *)shape
872+
dataType:(ExecuTorchDataType)dataType {
873+
return [self randomIntegerTensorWithLow:low
874+
high:high
875+
shape:shape
876+
strides:@[]
877+
dataType:dataType
878+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
879+
}
880+
881+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
882+
low:(NSInteger)low
883+
high:(NSInteger)high
884+
dataType:(ExecuTorchDataType)dataType
885+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
886+
return [self randomIntegerTensorWithLow:low
887+
high:high
888+
shape:tensor.shape
889+
strides:tensor.strides
890+
dataType:dataType
891+
shapeDynamism:shapeDynamism];
892+
}
893+
894+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
895+
low:(NSInteger)low
896+
high:(NSInteger)high
897+
dataType:(ExecuTorchDataType)dataType {
898+
return [self randomIntegerTensorWithLow:low
899+
high:high
900+
shape:tensor.shape
901+
strides:tensor.strides
902+
dataType:dataType
903+
shapeDynamism:tensor.shapeDynamism];
904+
}
905+
906+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
907+
low:(NSInteger)low
908+
high:(NSInteger)high {
909+
return [self randomIntegerTensorWithLow:low
910+
high:high
911+
shape:tensor.shape
912+
strides:tensor.strides
913+
dataType:tensor.dataType
914+
shapeDynamism:tensor.shapeDynamism];
915+
}
916+
917+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,4 +637,28 @@ class TensorTest: XCTestCase {
637637
XCTAssertEqual(tensor.shape, other.shape)
638638
XCTAssertEqual(tensor.count, other.count)
639639
}
640+
641+
func testRandomInteger() {
642+
let tensor = Tensor.randint(low: 10, high: 20, shape: [5], dataType: .int)
643+
XCTAssertEqual(tensor.shape, [5])
644+
XCTAssertEqual(tensor.count, 5)
645+
tensor.bytes { pointer, count, dataType in
646+
XCTAssertEqual(dataType, .int)
647+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
648+
for value in buffer {
649+
XCTAssertTrue(value >= 10 && value < 20)
650+
}
651+
}
652+
}
653+
654+
func testRandomIntegerLike() {
655+
let other = Tensor.ones(shape: [5], dataType: .int)
656+
let tensor = Tensor.randint(like: other, low: 100, high: 200)
657+
tensor.bytes { pointer, count, dataType in
658+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
659+
for value in buffer {
660+
XCTAssertTrue(value >= 100 && value < 200)
661+
}
662+
}
663+
}
640664
}

0 commit comments

Comments
 (0)