Skip to content

Commit 6c698f2

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Allow data casting along with cloning. (#15511)
Summary: . Differential Revision: D86070965
1 parent f832c65 commit 6c698f2

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,14 @@ public final class Tensor<T: Scalar>: Equatable {
741741
Tensor<T>(anyTensor.copy())
742742
}
743743

744+
/// Returns a copy of the tensor, converted to the specified scalar type.
745+
///
746+
/// - Parameter dataType: The target scalar type.
747+
/// - Returns: A new tensor with the same shape and metadata but converted elements.
748+
public func copy<U: Scalar>(to dataType: U.Type) -> Tensor<U> {
749+
Tensor<U>(anyTensor.copy(to: U.dataType))
750+
}
751+
744752
/// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements.
745753
///
746754
/// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ __attribute__((objc_subclassing_restricted))
168168
*/
169169
- (instancetype)copy;
170170

171+
/**
172+
 * Creates a deep copy of the tensor, potentially casting to a new data type.
173+
 * The new tensor will have its own copy of the data.
174+
 *
175+
 * @param dataType The desired data type for the new tensor.
176+
 * @return A new ExecuTorchTensor instance that is a duplicate (and possibly casted) of the current tensor.
177+
*/
178+
- (instancetype)copyToDataType:(ExecuTorchDataType)dataType
179+
NS_SWIFT_NAME(copy(to:));
180+
171181
/**
172182
* Executes a block with a pointer to the tensor's immutable byte data.
173183
*

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ - (instancetype)copyWithZone:(nullable NSZone *)zone {
147147
return [[ExecuTorchTensor allocWithZone:zone] initWithNativeInstance:&tensor];
148148
}
149149

150+
- (instancetype)copyToDataType:(ExecuTorchDataType)dataType {
151+
auto tensor = clone_tensor_ptr(_tensor, static_cast<ScalarType>(dataType));
152+
return [[ExecuTorchTensor alloc] initWithNativeInstance:&tensor];
153+
}
154+
150155
- (void *)nativeInstance {
151156
return &_tensor;
152157
}

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,43 @@ class TensorTest: XCTestCase {
182182
XCTAssertEqual(tensor1.count, tensor2.count)
183183
}
184184

185+
func testCopyToSameDataType() {
186+
let tensor1 = Tensor<Float>([1, 2, 3, 4], shape: [2, 2])
187+
let tensor2 = tensor1.copy(to: Float.self)
188+
XCTAssertEqual(tensor2.dataType, .float)
189+
XCTAssertEqual(tensor2.shape, [2, 2])
190+
XCTAssertEqual(tensor2.strides, tensor1.strides)
191+
XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder)
192+
XCTAssertEqual(tensor2.scalars(), [1, 2, 3, 4])
193+
}
194+
195+
func testCopyToDifferentDataTypeKeepsSourceAlive() {
196+
var data = [10.0, 20.0, 30.0, 40.0]
197+
let tensor1 = data.withUnsafeMutableBytes {
198+
Tensor<Double>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
199+
}
200+
let tensor2 = tensor1.copy(to: Float.self)
201+
data[0] = 999.0
202+
XCTAssertEqual(tensor2.dataType, .float)
203+
XCTAssertEqual(tensor2.shape, [2, 2])
204+
XCTAssertEqual(tensor2.scalars(), [10.0, 20.0, 30.0, 40.0])
205+
}
206+
207+
func testCopyToPreservesShapeAndOrderOn2D() {
208+
let tensor1 = Tensor<Int32>(
209+
[1, 2, 3, 4, 5, 6],
210+
shape: [2, 3],
211+
strides: [3, 1],
212+
dimensionOrder: [0, 1]
213+
)
214+
let tensor2 = tensor1.copy(to: Double.self)
215+
XCTAssertEqual(tensor2.shape, [2, 3])
216+
XCTAssertEqual(tensor2.strides, [3, 1])
217+
XCTAssertEqual(tensor2.dimensionOrder, [0, 1])
218+
XCTAssertEqual(tensor2.count, 6)
219+
XCTAssertEqual(tensor2.scalars(), [1, 2, 3, 4, 5, 6])
220+
}
221+
185222
func testResize() {
186223
var data: [Int] = [1, 2, 3, 4]
187224
let tensor = data.withUnsafeMutableBytes {

0 commit comments

Comments
 (0)