Skip to content

Commit 0b1a840

Browse files
authored
Tensor constructor to create with an existing Tensor.
Differential Revision: D71906972 Pull Request resolved: #9677
1 parent 2865652 commit 0b1a840

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ __attribute__((deprecated("This API is experimental.")))
151151
- (instancetype)initWithNativeInstance:(void *)nativeInstance
152152
NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE("");
153153

154+
/**
155+
* Creates a new tensor by copying an existing tensor.
156+
*
157+
* @param otherTensor The tensor instance to copy.
158+
* @return A new ExecuTorchTensor instance that is a copy of otherTensor.
159+
*/
160+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
161+
NS_SWIFT_NAME(init(_:));
162+
154163
/**
155164
* Executes a block with a pointer to the tensor's immutable byte data.
156165
*

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance {
4444
return self;
4545
}
4646

47+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
48+
ET_CHECK(otherTensor);
49+
auto tensor = make_tensor_ptr(
50+
**reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
51+
);
52+
return [self initWithNativeInstance:&tensor];
53+
}
54+
4755
- (void *)nativeInstance {
4856
return &_tensor;
4957
}

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,18 @@ class TensorTest: XCTestCase {
137137
XCTAssertEqual(updatedData, [2, 4, 6, 8])
138138
}
139139
}
140+
141+
func testInitWithTensor() {
142+
var data: [Int] = [10, 20, 30, 40]
143+
let tensor1 = data.withUnsafeMutableBytes {
144+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .int)
145+
}
146+
let tensor2 = Tensor(tensor1)
147+
148+
XCTAssertEqual(tensor2.dataType, tensor1.dataType)
149+
XCTAssertEqual(tensor2.shape, tensor1.shape)
150+
XCTAssertEqual(tensor2.strides, tensor1.strides)
151+
XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder)
152+
XCTAssertEqual(tensor2.count, tensor1.count)
153+
}
140154
}

0 commit comments

Comments
 (0)