Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,28 @@ __attribute__((deprecated("This API is experimental.")))

@end

#pragma mark - Data Category

@interface ExecuTorchTensor (Data)

/**
* Initializes a tensor using an NSData object as the underlying data buffer.
*
* @param data An NSData object containing the tensor data.
* @param shape An NSArray of NSNumber objects representing the tensor's shape.
* @param strides An NSArray of NSNumber objects representing the tensor's strides.
* @param dimensionOrder An NSArray of NSNumber objects indicating the order of dimensions.
* @param dataType An ExecuTorchDataType value specifying the element type.
* @param shapeDynamism An ExecuTorchShapeDynamism value indicating the shape dynamism.
* @return An initialized ExecuTorchTensor instance using the provided data.
*/
- (instancetype)initWithData:(NSData *)data
shape:(NSArray<NSNumber *> *)shape
strides:(NSArray<NSNumber *> *)strides
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
dataType:(ExecuTorchDataType)dataType
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism;

@end

NS_ASSUME_NONNULL_END
25 changes: 25 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ NSInteger ExecuTorchElementCountOfShape(NSArray<NSNumber *> *shape) {

@implementation ExecuTorchTensor {
TensorPtr _tensor;
NSData *_data;
NSArray<NSNumber *> *_shape;
NSArray<NSNumber *> *_strides;
NSArray<NSNumber *> *_dimensionOrder;
Expand Down Expand Up @@ -274,3 +275,27 @@ - (instancetype)initWithBytes:(const void *)pointer
}

@end

@implementation ExecuTorchTensor (Data)

- (instancetype)initWithData:(NSData *)data
shape:(NSArray<NSNumber *> *)shape
strides:(NSArray<NSNumber *> *)strides
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
dataType:(ExecuTorchDataType)dataType
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
ET_CHECK_MSG(data.length >= ExecuTorchElementCountOfShape(shape) * ExecuTorchSizeOfDataType(dataType),
"Data length is too small");
self = [self initWithBytesNoCopy:(void *)data.bytes
shape:shape
strides:strides
dimensionOrder:dimensionOrder
dataType:dataType
shapeDynamism:shapeDynamism];
if (self) {
_data = data;
}
return self;
}

@end
10 changes: 10 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ class TensorTest: XCTestCase {
}
}

func testInitData() {
let dataArray: [Float] = [1.0, 2.0, 3.0, 4.0]
let data = Data(bytes: dataArray, count: dataArray.count * MemoryLayout<Float>.size)
let tensor = Tensor(data: data, shape: [4], strides: [1], dimensionOrder: [0], dataType: .float, shapeDynamism: .static)
XCTAssertEqual(tensor.count, 4)
tensor.bytes { pointer, count, dataType in
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), dataArray)
}
}

func testWithCustomStridesAndDimensionOrder() {
let data: [Float] = [1.0, 2.0, 3.0, 4.0]
let tensor = Tensor(
Expand Down
Loading