diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index a7e5004156d..07b43ef1818 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -345,4 +345,28 @@ __attribute__((deprecated("This API is experimental."))) @end +#pragma mark - Data Category + +@interface ExecuTorchTensor (Data) + +/** + * Initializes a tensor using an NSData object as the underlying data buffer. + * + * @param data An NSData object containing the tensor data. + * @param shape An NSArray of NSNumber objects representing the tensor's shape. + * @param strides An NSArray of NSNumber objects representing the tensor's strides. + * @param dimensionOrder An NSArray of NSNumber objects indicating the order of dimensions. + * @param dataType An ExecuTorchDataType value specifying the element type. + * @param shapeDynamism An ExecuTorchShapeDynamism value indicating the shape dynamism. + * @return An initialized ExecuTorchTensor instance using the provided data. + */ +- (instancetype)initWithData:(NSData *)data + shape:(NSArray *)shape + strides:(NSArray *)strides + dimensionOrder:(NSArray *)dimensionOrder + dataType:(ExecuTorchDataType)dataType + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + +@end + NS_ASSUME_NONNULL_END diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index b18e81fa8af..fd04359df46 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -31,6 +31,7 @@ NSInteger ExecuTorchElementCountOfShape(NSArray *shape) { @implementation ExecuTorchTensor { TensorPtr _tensor; + NSData *_data; NSArray *_shape; NSArray *_strides; NSArray *_dimensionOrder; @@ -274,3 +275,27 @@ - (instancetype)initWithBytes:(const void *)pointer } @end + +@implementation ExecuTorchTensor (Data) + +- (instancetype)initWithData:(NSData *)data + shape:(NSArray *)shape + strides:(NSArray *)strides + dimensionOrder:(NSArray *)dimensionOrder + dataType:(ExecuTorchDataType)dataType + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism { + ET_CHECK_MSG(data.length >= ExecuTorchElementCountOfShape(shape) * ExecuTorchSizeOfDataType(dataType), + "Data length is too small"); + self = [self initWithBytesNoCopy:(void *)data.bytes + shape:shape + strides:strides + dimensionOrder:dimensionOrder + dataType:dataType + shapeDynamism:shapeDynamism]; + if (self) { + _data = data; + } + return self; +} + +@end diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index f4bdc9927ae..32dde3db7bc 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -101,6 +101,16 @@ class TensorTest: XCTestCase { } } + func testInitData() { + let dataArray: [Float] = [1.0, 2.0, 3.0, 4.0] + let data = Data(bytes: dataArray, count: dataArray.count * MemoryLayout.size) + let tensor = Tensor(data: data, shape: [4], strides: [1], dimensionOrder: [0], dataType: .float, shapeDynamism: .static) + XCTAssertEqual(tensor.count, 4) + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), dataArray) + } + } + func testWithCustomStridesAndDimensionOrder() { let data: [Float] = [1.0, 2.0, 3.0, 4.0] let tensor = Tensor(