@@ -678,6 +678,66 @@ class TensorTest: XCTestCase {
678
678
XCTAssertEqual ( try tensor. scalars ( ) . first, 42 )
679
679
}
680
680
681
+ func testExtractAnyTensorMatchesOriginalDataAndMetadata( ) {
682
+ let tensor = Tensor ( [ 1 , 2 , 3 , 4 ] , shape: [ 2 , 2 ] )
683
+ let anyTensor = tensor. anyTensor
684
+ XCTAssertEqual ( anyTensor. shape, tensor. shape)
685
+ XCTAssertEqual ( anyTensor. strides, tensor. strides)
686
+ XCTAssertEqual ( anyTensor. dimensionOrder, tensor. dimensionOrder)
687
+ XCTAssertEqual ( anyTensor. count, tensor. count)
688
+ XCTAssertEqual ( anyTensor. dataType, tensor. dataType)
689
+ XCTAssertEqual ( anyTensor. shapeDynamism, tensor. shapeDynamism)
690
+ let newTensor = Tensor < Int > ( anyTensor)
691
+ XCTAssertEqual ( newTensor, tensor)
692
+ }
693
+
694
+ func testReconstructGenericTensorViaInitAndAsTensor( ) {
695
+ let tensor = Tensor ( [ 5 , 6 , 7 ] )
696
+ let anyTensor = tensor. anyTensor
697
+ let tensorInit = Tensor < Int > ( anyTensor)
698
+ let tensorFromAny : Tensor < Int > = anyTensor. asTensor ( ) !
699
+ XCTAssertEqual ( tensorInit, tensorFromAny)
700
+ }
701
+
702
+ func testAsTensorMismatchedTypeReturnsNil( ) {
703
+ let tensor = Tensor ( [ 8 , 9 , 10 ] )
704
+ let anyTensor = tensor. anyTensor
705
+ let wrongTypedTensor : Tensor < Float > ? = anyTensor. asTensor ( )
706
+ XCTAssertNil ( wrongTypedTensor)
707
+ }
708
+
709
+ func testViewSharesDataAndResizeAltersShapeNotData( ) throws {
710
+ var scalars = [ 11 , 12 , 13 , 14 ]
711
+ let tensor = Tensor ( & scalars, shape: [ 2 , 2 ] )
712
+ let viewTensor = Tensor ( tensor)
713
+ let scalarsAddress = scalars. withUnsafeBufferPointer { $0. baseAddress }
714
+ let tensorDataAddress = try tensor. withUnsafeBytes { $0. baseAddress }
715
+ let viewTensorDataAddress = try viewTensor. withUnsafeBytes { $0. baseAddress }
716
+ XCTAssertEqual ( tensorDataAddress, scalarsAddress)
717
+ XCTAssertEqual ( tensorDataAddress, viewTensorDataAddress)
718
+
719
+ scalars [ 2 ] = 42
720
+ XCTAssertEqual ( try tensor. scalars ( ) , scalars)
721
+ XCTAssertEqual ( try viewTensor. scalars ( ) , scalars)
722
+
723
+ XCTAssertNoThrow ( try viewTensor. resize ( to: [ 4 , 1 ] ) )
724
+ XCTAssertEqual ( viewTensor. shape, [ 4 , 1 ] )
725
+ XCTAssertEqual ( tensor. shape, [ 2 , 2 ] )
726
+ XCTAssertEqual ( try tensor. scalars ( ) , scalars)
727
+ XCTAssertEqual ( try viewTensor. scalars ( ) , scalars)
728
+ }
729
+
730
+ func testMultipleGenericFromAnyReflectChanges( ) {
731
+ let tensor = Tensor ( [ 2 , 4 , 6 , 8 ] , shape: [ 2 , 2 ] )
732
+ let anyTensor = tensor. anyTensor
733
+ let tensor1 : Tensor < Int > = anyTensor. asTensor ( ) !
734
+ let tensor2 : Tensor < Int > = anyTensor. asTensor ( ) !
735
+
736
+ XCTAssertEqual ( tensor1, tensor2)
737
+ XCTAssertNoThrow ( try tensor1. withUnsafeMutableBytes { $0 [ 1 ] = 42 } )
738
+ XCTAssertEqual ( try tensor2. withUnsafeBytes { $0 [ 1 ] } , 42 )
739
+ }
740
+
681
741
func testEmpty( ) {
682
742
let tensor = Tensor< Float> . empty( shape: [ 3 , 4 ] )
683
743
XCTAssertEqual ( tensor. shape, [ 3 , 4 ] )
0 commit comments