Skip to content

Commit deaaab5

Browse files
authored
Expose type-erased tensor from generic one. (#11962)
1 parent bbfcc2a commit deaaab5

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -821,9 +821,8 @@ public class Tensor<T: Scalar>: Equatable {
821821
lhs.anyTensor == rhs.anyTensor
822822
}
823823

824-
// MARK: Internal
825-
826-
let anyTensor: AnyTensor
824+
// Wrapped AnyTensor instance.
825+
public let anyTensor: AnyTensor
827826
}
828827

829828
@available(*, deprecated, message: "This API is experimental.")

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,66 @@ class TensorTest: XCTestCase {
678678
XCTAssertEqual(try tensor.scalars().first, 42)
679679
}
680680

681+
func testExtractAnyTensorMatchesOriginalDataAndMetadata() {
682+
let tensor = Tensor([1, 2, 3, 4], shape: [2, 2])
683+
let anyTensor = tensor.anyTensor
684+
XCTAssertEqual(anyTensor.shape, tensor.shape)
685+
XCTAssertEqual(anyTensor.strides, tensor.strides)
686+
XCTAssertEqual(anyTensor.dimensionOrder, tensor.dimensionOrder)
687+
XCTAssertEqual(anyTensor.count, tensor.count)
688+
XCTAssertEqual(anyTensor.dataType, tensor.dataType)
689+
XCTAssertEqual(anyTensor.shapeDynamism, tensor.shapeDynamism)
690+
let newTensor = Tensor<Int>(anyTensor)
691+
XCTAssertEqual(newTensor, tensor)
692+
}
693+
694+
func testReconstructGenericTensorViaInitAndAsTensor() {
695+
let tensor = Tensor([5, 6, 7])
696+
let anyTensor = tensor.anyTensor
697+
let tensorInit = Tensor<Int>(anyTensor)
698+
let tensorFromAny: Tensor<Int> = anyTensor.asTensor()!
699+
XCTAssertEqual(tensorInit, tensorFromAny)
700+
}
701+
702+
func testAsTensorMismatchedTypeReturnsNil() {
703+
let tensor = Tensor([8, 9, 10])
704+
let anyTensor = tensor.anyTensor
705+
let wrongTypedTensor: Tensor<Float>? = anyTensor.asTensor()
706+
XCTAssertNil(wrongTypedTensor)
707+
}
708+
709+
func testViewSharesDataAndResizeAltersShapeNotData() throws {
710+
var scalars = [11, 12, 13, 14]
711+
let tensor = Tensor(&scalars, shape: [2, 2])
712+
let viewTensor = Tensor(tensor)
713+
let scalarsAddress = scalars.withUnsafeBufferPointer { $0.baseAddress }
714+
let tensorDataAddress = try tensor.withUnsafeBytes { $0.baseAddress }
715+
let viewTensorDataAddress = try viewTensor.withUnsafeBytes { $0.baseAddress }
716+
XCTAssertEqual(tensorDataAddress, scalarsAddress)
717+
XCTAssertEqual(tensorDataAddress, viewTensorDataAddress)
718+
719+
scalars[2] = 42
720+
XCTAssertEqual(try tensor.scalars(), scalars)
721+
XCTAssertEqual(try viewTensor.scalars(), scalars)
722+
723+
XCTAssertNoThrow(try viewTensor.resize(to: [4, 1]))
724+
XCTAssertEqual(viewTensor.shape, [4, 1])
725+
XCTAssertEqual(tensor.shape, [2, 2])
726+
XCTAssertEqual(try tensor.scalars(), scalars)
727+
XCTAssertEqual(try viewTensor.scalars(), scalars)
728+
}
729+
730+
func testMultipleGenericFromAnyReflectChanges() {
731+
let tensor = Tensor([2, 4, 6, 8], shape: [2, 2])
732+
let anyTensor = tensor.anyTensor
733+
let tensor1: Tensor<Int> = anyTensor.asTensor()!
734+
let tensor2: Tensor<Int> = anyTensor.asTensor()!
735+
736+
XCTAssertEqual(tensor1, tensor2)
737+
XCTAssertNoThrow(try tensor1.withUnsafeMutableBytes { $0[1] = 42 })
738+
XCTAssertEqual(try tensor2.withUnsafeBytes { $0[1] }, 42)
739+
}
740+
681741
func testEmpty() {
682742
let tensor = Tensor<Float>.empty(shape: [3, 4])
683743
XCTAssertEqual(tensor.shape, [3, 4])

0 commit comments

Comments
 (0)