Skip to content

Commit 2865652

Browse files
authored
Tensor accessorts to get raw data buffer.
Differential Revision: D71905971 Pull Request resolved: #9676
1 parent dab0d3c commit 2865652

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,28 @@ __attribute__((deprecated("This API is experimental.")))
151151
- (instancetype)initWithNativeInstance:(void *)nativeInstance
152152
NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE("");
153153

154+
/**
155+
* Executes a block with a pointer to the tensor's immutable byte data.
156+
*
157+
* @param handler A block that receives:
158+
* - a pointer to the data,
159+
* - the total number of elements,
160+
* - and the data type.
161+
*/
162+
- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType dataType))handler
163+
NS_SWIFT_NAME(bytes(_:));
164+
165+
/**
166+
* Executes a block with a pointer to the tensor's mutable byte data.
167+
*
168+
* @param handler A block that receives:
169+
* - a mutable pointer to the data,
170+
* - the total number of elements,
171+
* - and the data type.
172+
*/
173+
- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler
174+
NS_SWIFT_NAME(mutableBytes(_:));
175+
154176
+ (instancetype)new NS_UNAVAILABLE;
155177
- (instancetype)init NS_UNAVAILABLE;
156178

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ - (NSInteger)count {
8181
return _tensor->numel();
8282
}
8383

84+
- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType type))handler {
85+
ET_CHECK(handler);
86+
handler(_tensor->unsafeGetTensorImpl()->data(), self.count, self.dataType);
87+
}
88+
89+
- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler {
90+
ET_CHECK(handler);
91+
handler(_tensor->unsafeGetTensorImpl()->mutable_data(), self.count, self.dataType);
92+
}
93+
8494
@end
8595

8696
@implementation ExecuTorchTensor (BytesNoCopy)

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,45 @@ class TensorTest: XCTestCase {
6060
let tensor = data.withUnsafeMutableBytes {
6161
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
6262
}
63+
// Modify the original data to make sure the tensor does not copy the data.
64+
data.indices.forEach { data[$0] += 1 }
65+
6366
XCTAssertEqual(tensor.dataType, .float)
6467
XCTAssertEqual(tensor.shape, [2, 3])
6568
XCTAssertEqual(tensor.strides, [3, 1])
6669
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
6770
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
6871
XCTAssertEqual(tensor.count, 6)
72+
73+
tensor.bytes { pointer, count, dataType in
74+
XCTAssertEqual(dataType, .float)
75+
XCTAssertEqual(count, 6)
76+
XCTAssertEqual(size(ofDataType: dataType), 4)
77+
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data)
78+
}
6979
}
7080

7181
func testInitBytes() {
7282
var data: [Double] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
7383
let tensor = data.withUnsafeMutableBytes {
7484
Tensor(bytes: $0.baseAddress!, shape: [2, 3], dataType: .double)
7585
}
86+
// Modify the original data to make sure the tensor copies the data.
87+
data.indices.forEach { data[$0] += 1 }
88+
7689
XCTAssertEqual(tensor.dataType, .double)
7790
XCTAssertEqual(tensor.shape, [2, 3])
7891
XCTAssertEqual(tensor.strides, [3, 1])
7992
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
8093
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
8194
XCTAssertEqual(tensor.count, 6)
95+
96+
tensor.bytes { pointer, count, dataType in
97+
XCTAssertEqual(dataType, .double)
98+
XCTAssertEqual(count, 6)
99+
XCTAssertEqual(size(ofDataType: dataType), 8)
100+
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)).map { $0 + 1 }, data)
101+
}
82102
}
83103

84104
func testWithCustomStridesAndDimensionOrder() {
@@ -94,5 +114,27 @@ class TensorTest: XCTestCase {
94114
XCTAssertEqual(tensor.strides, [1, 2])
95115
XCTAssertEqual(tensor.dimensionOrder, [1, 0])
96116
XCTAssertEqual(tensor.count, 4)
117+
118+
tensor.bytes { pointer, count, dataType in
119+
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data)
120+
}
121+
}
122+
123+
func testMutableBytes() {
124+
var data: [Int32] = [1, 2, 3, 4]
125+
let tensor = data.withUnsafeMutableBytes {
126+
Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .int)
127+
}
128+
tensor.mutableBytes { pointer, count, dataType in
129+
XCTAssertEqual(dataType, .int)
130+
let buffer = pointer.assumingMemoryBound(to: Int32.self)
131+
for i in 0..<count {
132+
buffer[i] *= 2
133+
}
134+
}
135+
tensor.bytes { pointer, count, dataType in
136+
let updatedData = Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count))
137+
XCTAssertEqual(updatedData, [2, 4, 6, 8])
138+
}
97139
}
98140
}

0 commit comments

Comments
 (0)