@@ -60,25 +60,45 @@ class TensorTest: XCTestCase {
6060 let tensor = data. withUnsafeMutableBytes {
6161 Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 3 ] , dataType: . float)
6262 }
63+ // Modify the original data to make sure the tensor does not copy the data.
64+ data. indices. forEach { data [ $0] += 1 }
65+
6366 XCTAssertEqual ( tensor. dataType, . float)
6467 XCTAssertEqual ( tensor. shape, [ 2 , 3 ] )
6568 XCTAssertEqual ( tensor. strides, [ 3 , 1 ] )
6669 XCTAssertEqual ( tensor. dimensionOrder, [ 0 , 1 ] )
6770 XCTAssertEqual ( tensor. shapeDynamism, . dynamicBound)
6871 XCTAssertEqual ( tensor. count, 6 )
72+
73+ tensor. bytes { pointer, count, dataType in
74+ XCTAssertEqual ( dataType, . float)
75+ XCTAssertEqual ( count, 6 )
76+ XCTAssertEqual ( size ( ofDataType: dataType) , 4 )
77+ XCTAssertEqual ( Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Float . self) , count: count) ) , data)
78+ }
6979 }
7080
7181 func testInitBytes( ) {
7282 var data : [ Double ] = [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
7383 let tensor = data. withUnsafeMutableBytes {
7484 Tensor ( bytes: $0. baseAddress!, shape: [ 2 , 3 ] , dataType: . double)
7585 }
86+ // Modify the original data to make sure the tensor copies the data.
87+ data. indices. forEach { data [ $0] += 1 }
88+
7689 XCTAssertEqual ( tensor. dataType, . double)
7790 XCTAssertEqual ( tensor. shape, [ 2 , 3 ] )
7891 XCTAssertEqual ( tensor. strides, [ 3 , 1 ] )
7992 XCTAssertEqual ( tensor. dimensionOrder, [ 0 , 1 ] )
8093 XCTAssertEqual ( tensor. shapeDynamism, . dynamicBound)
8194 XCTAssertEqual ( tensor. count, 6 )
95+
96+ tensor. bytes { pointer, count, dataType in
97+ XCTAssertEqual ( dataType, . double)
98+ XCTAssertEqual ( count, 6 )
99+ XCTAssertEqual ( size ( ofDataType: dataType) , 8 )
100+ XCTAssertEqual ( Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Double . self) , count: count) ) . map { $0 + 1 } , data)
101+ }
82102 }
83103
84104 func testWithCustomStridesAndDimensionOrder( ) {
@@ -94,5 +114,27 @@ class TensorTest: XCTestCase {
94114 XCTAssertEqual ( tensor. strides, [ 1 , 2 ] )
95115 XCTAssertEqual ( tensor. dimensionOrder, [ 1 , 0 ] )
96116 XCTAssertEqual ( tensor. count, 4 )
117+
118+ tensor. bytes { pointer, count, dataType in
119+ XCTAssertEqual ( Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Float . self) , count: count) ) , data)
120+ }
121+ }
122+
123+ func testMutableBytes( ) {
124+ var data : [ Int32 ] = [ 1 , 2 , 3 , 4 ]
125+ let tensor = data. withUnsafeMutableBytes {
126+ Tensor ( bytes: $0. baseAddress!, shape: [ 4 ] , dataType: . int)
127+ }
128+ tensor. mutableBytes { pointer, count, dataType in
129+ XCTAssertEqual ( dataType, . int)
130+ let buffer = pointer. assumingMemoryBound ( to: Int32 . self)
131+ for i in 0 ..< count {
132+ buffer [ i] *= 2
133+ }
134+ }
135+ tensor. bytes { pointer, count, dataType in
136+ let updatedData = Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Int32 . self) , count: count) )
137+ XCTAssertEqual ( updatedData, [ 2 , 4 , 6 , 8 ] )
138+ }
97139 }
98140}
0 commit comments