Skip to content

Commit dab0d3c

Browse files
authored
Overloads for the bytes Tensor constrcutor.
Differential Revision: D71905631 Pull Request resolved: #9675
1 parent 7030b33 commit dab0d3c

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,48 @@ __attribute__((deprecated("This API is experimental.")))
244244
dataType:(ExecuTorchDataType)dataType
245245
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism;
246246

247+
/**
248+
* Initializes a tensor by copying bytes from the provided pointer with dynamic bound shape.
249+
*
250+
* @param pointer A pointer to the source data buffer.
251+
* @param shape An NSArray of NSNumber objects representing the tensor's shape.
252+
* @param strides An NSArray of NSNumber objects representing the tensor's strides.
253+
* @param dimensionOrder An NSArray of NSNumber objects indicating the order of dimensions.
254+
* @param dataType An ExecuTorchDataType value specifying the element type.
255+
* @return An initialized ExecuTorchTensor instance with its own copy of the data.
256+
*/
257+
- (instancetype)initWithBytes:(const void *)pointer
258+
shape:(NSArray<NSNumber *> *)shape
259+
strides:(NSArray<NSNumber *> *)strides
260+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
261+
dataType:(ExecuTorchDataType)dataType;
262+
263+
/**
264+
* Initializes a tensor by copying bytes from the provided pointer, specifying shape, data type, and explicit shape dynamism.
265+
*
266+
* @param pointer A pointer to the source data buffer.
267+
* @param shape An NSArray of NSNumber objects representing the tensor's shape.
268+
* @param dataType An ExecuTorchDataType value specifying the element type.
269+
* @param shapeDynamism An ExecuTorchShapeDynamism value indicating the shape dynamism.
270+
* @return An initialized ExecuTorchTensor instance with its own copy of the data.
271+
*/
272+
- (instancetype)initWithBytes:(const void *)pointer
273+
shape:(NSArray<NSNumber *> *)shape
274+
dataType:(ExecuTorchDataType)dataType
275+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism;
276+
277+
/**
278+
* Initializes a tensor by copying bytes from the provided pointer, specifying only the shape and data type.
279+
*
280+
* @param pointer A pointer to the source data buffer.
281+
* @param shape An NSArray of NSNumber objects representing the tensor's shape.
282+
* @param dataType An ExecuTorchDataType value specifying the element type.
283+
* @return An initialized ExecuTorchTensor instance with its own copy of the data.
284+
*/
285+
- (instancetype)initWithBytes:(const void *)pointer
286+
shape:(NSArray<NSNumber *> *)shape
287+
dataType:(ExecuTorchDataType)dataType;
288+
247289
@end
248290

249291
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,40 @@ - (instancetype)initWithBytes:(const void *)pointer
164164
return [self initWithNativeInstance:&tensor];
165165
}
166166

167+
- (instancetype)initWithBytes:(const void *)pointer
168+
shape:(NSArray<NSNumber *> *)shape
169+
strides:(NSArray<NSNumber *> *)strides
170+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
171+
dataType:(ExecuTorchDataType)dataType {
172+
return [self initWithBytes:pointer
173+
shape:shape
174+
strides:strides
175+
dimensionOrder:dimensionOrder
176+
dataType:dataType
177+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
178+
}
179+
180+
- (instancetype)initWithBytes:(const void *)pointer
181+
shape:(NSArray<NSNumber *> *)shape
182+
dataType:(ExecuTorchDataType)dataType
183+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
184+
return [self initWithBytes:pointer
185+
shape:shape
186+
strides:@[]
187+
dimensionOrder:@[]
188+
dataType:dataType
189+
shapeDynamism:shapeDynamism];
190+
}
191+
192+
- (instancetype)initWithBytes:(const void *)pointer
193+
shape:(NSArray<NSNumber *> *)shape
194+
dataType:(ExecuTorchDataType)dataType {
195+
return [self initWithBytes:pointer
196+
shape:shape
197+
strides:@[]
198+
dimensionOrder:@[]
199+
dataType:dataType
200+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
201+
}
202+
167203
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,32 @@ class TensorTest: XCTestCase {
6767
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
6868
XCTAssertEqual(tensor.count, 6)
6969
}
70+
71+
func testInitBytes() {
72+
var data: [Double] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
73+
let tensor = data.withUnsafeMutableBytes {
74+
Tensor(bytes: $0.baseAddress!, shape: [2, 3], dataType: .double)
75+
}
76+
XCTAssertEqual(tensor.dataType, .double)
77+
XCTAssertEqual(tensor.shape, [2, 3])
78+
XCTAssertEqual(tensor.strides, [3, 1])
79+
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
80+
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
81+
XCTAssertEqual(tensor.count, 6)
82+
}
83+
84+
func testWithCustomStridesAndDimensionOrder() {
85+
let data: [Float] = [1.0, 2.0, 3.0, 4.0]
86+
let tensor = Tensor(
87+
bytes: data.withUnsafeBytes { $0.baseAddress! },
88+
shape: [2, 2],
89+
strides: [1, 2],
90+
dimensionOrder: [1, 0],
91+
dataType: .float
92+
)
93+
XCTAssertEqual(tensor.shape, [2, 2])
94+
XCTAssertEqual(tensor.strides, [1, 2])
95+
XCTAssertEqual(tensor.dimensionOrder, [1, 0])
96+
XCTAssertEqual(tensor.count, 4)
97+
}
7098
}

0 commit comments

Comments
 (0)