@@ -678,6 +678,66 @@ class TensorTest: XCTestCase {
678678 XCTAssertEqual ( try tensor. scalars ( ) . first, 42 )
679679 }
680680
681+ func testExtractAnyTensorMatchesOriginalDataAndMetadata( ) {
682+ let tensor = Tensor ( [ 1 , 2 , 3 , 4 ] , shape: [ 2 , 2 ] )
683+ let anyTensor = tensor. anyTensor
684+ XCTAssertEqual ( anyTensor. shape, tensor. shape)
685+ XCTAssertEqual ( anyTensor. strides, tensor. strides)
686+ XCTAssertEqual ( anyTensor. dimensionOrder, tensor. dimensionOrder)
687+ XCTAssertEqual ( anyTensor. count, tensor. count)
688+ XCTAssertEqual ( anyTensor. dataType, tensor. dataType)
689+ XCTAssertEqual ( anyTensor. shapeDynamism, tensor. shapeDynamism)
690+ let newTensor = Tensor < Int > ( anyTensor)
691+ XCTAssertEqual ( newTensor, tensor)
692+ }
693+
694+ func testReconstructGenericTensorViaInitAndAsTensor( ) {
695+ let tensor = Tensor ( [ 5 , 6 , 7 ] )
696+ let anyTensor = tensor. anyTensor
697+ let tensorInit = Tensor < Int > ( anyTensor)
698+ let tensorFromAny : Tensor < Int > = anyTensor. asTensor ( ) !
699+ XCTAssertEqual ( tensorInit, tensorFromAny)
700+ }
701+
702+ func testAsTensorMismatchedTypeReturnsNil( ) {
703+ let tensor = Tensor ( [ 8 , 9 , 10 ] )
704+ let anyTensor = tensor. anyTensor
705+ let wrongTypedTensor : Tensor < Float > ? = anyTensor. asTensor ( )
706+ XCTAssertNil ( wrongTypedTensor)
707+ }
708+
709+ func testViewSharesDataAndResizeAltersShapeNotData( ) throws {
710+ var scalars = [ 11 , 12 , 13 , 14 ]
711+ let tensor = Tensor ( & scalars, shape: [ 2 , 2 ] )
712+ let viewTensor = Tensor ( tensor)
713+ let scalarsAddress = scalars. withUnsafeBufferPointer { $0. baseAddress }
714+ let tensorDataAddress = try tensor. withUnsafeBytes { $0. baseAddress }
715+ let viewTensorDataAddress = try viewTensor. withUnsafeBytes { $0. baseAddress }
716+ XCTAssertEqual ( tensorDataAddress, scalarsAddress)
717+ XCTAssertEqual ( tensorDataAddress, viewTensorDataAddress)
718+
719+ scalars [ 2 ] = 42
720+ XCTAssertEqual ( try tensor. scalars ( ) , scalars)
721+ XCTAssertEqual ( try viewTensor. scalars ( ) , scalars)
722+
723+ XCTAssertNoThrow ( try viewTensor. resize ( to: [ 4 , 1 ] ) )
724+ XCTAssertEqual ( viewTensor. shape, [ 4 , 1 ] )
725+ XCTAssertEqual ( tensor. shape, [ 2 , 2 ] )
726+ XCTAssertEqual ( try tensor. scalars ( ) , scalars)
727+ XCTAssertEqual ( try viewTensor. scalars ( ) , scalars)
728+ }
729+
730+ func testMultipleGenericFromAnyReflectChanges( ) {
731+ let tensor = Tensor ( [ 2 , 4 , 6 , 8 ] , shape: [ 2 , 2 ] )
732+ let anyTensor = tensor. anyTensor
733+ let tensor1 : Tensor < Int > = anyTensor. asTensor ( ) !
734+ let tensor2 : Tensor < Int > = anyTensor. asTensor ( ) !
735+
736+ XCTAssertEqual ( tensor1, tensor2)
737+ XCTAssertNoThrow ( try tensor1. withUnsafeMutableBytes { $0 [ 1 ] = 42 } )
738+ XCTAssertEqual ( try tensor2. withUnsafeBytes { $0 [ 1 ] } , 42 )
739+ }
740+
681741 func testEmpty( ) {
682742 let tensor = Tensor< Float> . empty( shape: [ 3 , 4 ] )
683743 XCTAssertEqual ( tensor. shape, [ 3 , 4 ] )
0 commit comments