@@ -56,8 +56,8 @@ class TensorTest: XCTestCase {
5656
5757 func testInitBytesNoCopy( ) {
5858 var data : [ Float ] = [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
59- let tensor : Tensor < Float > = data. withUnsafeMutableBytes {
60- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 3 ] )
59+ let tensor = data. withUnsafeMutableBytes {
60+ Tensor < Float > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 3 ] )
6161 }
6262 // Modify the original data to make sure the tensor does not copy the data.
6363 data. indices. forEach { data [ $0] += 1 }
@@ -73,8 +73,8 @@ class TensorTest: XCTestCase {
7373
7474 func testInitBytes( ) {
7575 var data : [ Double ] = [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
76- let tensor : Tensor < Double > = data. withUnsafeMutableBytes {
77- Tensor ( bytes: $0. baseAddress!, shape: [ 2 , 3 ] )
76+ let tensor = data. withUnsafeMutableBytes {
77+ Tensor < Double > ( bytes: $0. baseAddress!, shape: [ 2 , 3 ] )
7878 }
7979 // Modify the original data to make sure the tensor copies the data.
8080 data. indices. forEach { data [ $0] += 1 }
@@ -91,14 +91,14 @@ class TensorTest: XCTestCase {
9191 func testInitData( ) {
9292 let dataArray : [ Float ] = [ 1.0 , 2.0 , 3.0 , 4.0 ]
9393 let data = Data ( bytes: dataArray, count: dataArray. count * MemoryLayout< Float> . size)
94- let tensor : Tensor < Float > = Tensor ( data: data, shape: [ 4 ] )
94+ let tensor = Tensor < Float > ( data: data, shape: [ 4 ] )
9595 XCTAssertEqual ( tensor. count, 4 )
9696 XCTAssertEqual ( try tensor. scalars ( ) , dataArray)
9797 }
9898
9999 func testWithCustomStridesAndDimensionOrder( ) {
100100 let data : [ Float ] = [ 1.0 , 2.0 , 3.0 , 4.0 ]
101- let tensor : Tensor < Float > = Tensor (
101+ let tensor = Tensor < Float > (
102102 bytes: data. withUnsafeBytes { $0. baseAddress! } ,
103103 shape: [ 2 , 2 ] ,
104104 strides: [ 1 , 2 ] ,
@@ -113,8 +113,8 @@ class TensorTest: XCTestCase {
113113
114114 func testMutableBytes( ) {
115115 var data : [ Int32 ] = [ 1 , 2 , 3 , 4 ]
116- let tensor : Tensor < Int32 > = data. withUnsafeMutableBytes {
117- Tensor ( bytes: $0. baseAddress!, shape: [ 4 ] )
116+ let tensor = data. withUnsafeMutableBytes {
117+ Tensor < Int32 > ( bytes: $0. baseAddress!, shape: [ 4 ] )
118118 }
119119 XCTAssertNoThrow ( try tensor. withUnsafeMutableBytes { buffer in
120120 for i in buffer. indices {
@@ -126,8 +126,8 @@ class TensorTest: XCTestCase {
126126
127127 func testInitWithTensor( ) throws {
128128 var data : [ Int ] = [ 10 , 20 , 30 , 40 ]
129- let tensor1 : Tensor < Int > = data. withUnsafeMutableBytes {
130- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
129+ let tensor1 = data. withUnsafeMutableBytes {
130+ Tensor < Int > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
131131 }
132132 let tensor2 = Tensor ( tensor1)
133133
@@ -157,8 +157,8 @@ class TensorTest: XCTestCase {
157157
158158 func testCopy( ) {
159159 var data : [ Double ] = [ 10.0 , 20.0 , 30.0 , 40.0 ]
160- let tensor1 : Tensor < Double > = data. withUnsafeMutableBytes {
161- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
160+ let tensor1 = data. withUnsafeMutableBytes {
161+ Tensor < Double > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
162162 }
163163 let tensor2 = tensor1. copy ( )
164164
@@ -171,8 +171,8 @@ class TensorTest: XCTestCase {
171171
172172 func testResize( ) {
173173 var data : [ Int ] = [ 1 , 2 , 3 , 4 ]
174- let tensor : Tensor < Int > = data. withUnsafeMutableBytes {
175- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
174+ let tensor = data. withUnsafeMutableBytes {
175+ Tensor < Int > ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
176176 }
177177 XCTAssertNoThrow ( try tensor. resize ( to: [ 2 , 2 ] ) )
178178 XCTAssertEqual ( tensor. dataType, . long)
@@ -185,38 +185,84 @@ class TensorTest: XCTestCase {
185185
186186 func testResizeError( ) {
187187 var data : [ Int ] = [ 1 , 2 , 3 , 4 ]
188- let tensor : Tensor < Int > = data. withUnsafeMutableBytes {
189- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
188+ let tensor = data. withUnsafeMutableBytes {
189+ Tensor < Int > ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
190190 }
191191 XCTAssertThrowsError ( try tensor. resize ( to: [ 2 , 3 ] ) )
192192 }
193193
194194 func testIsEqual( ) {
195195 var data : [ Float ] = [ 1.0 , 2.0 , 3.0 , 4.0 ]
196- let tensor1 : Tensor < Float > = data. withUnsafeMutableBytes {
197- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
196+ let tensor1 = data. withUnsafeMutableBytes {
197+ Tensor < Float > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
198198 }
199199 let tensor2 = Tensor ( tensor1)
200200 XCTAssertEqual ( tensor1, tensor2)
201201 XCTAssertEqual ( tensor2, tensor1)
202202
203203 var dataModified : [ Float ] = [ 1.0 , 2.0 , 3.0 , 5.0 ]
204- let tensor3 : Tensor < Float > = dataModified. withUnsafeMutableBytes {
205- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
204+ let tensor3 = dataModified. withUnsafeMutableBytes {
205+ Tensor < Float > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] )
206206 }
207207 XCTAssertNotEqual ( tensor1, tensor3)
208- let tensor4 : Tensor < Float > = data. withUnsafeMutableBytes {
209- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
208+ let tensor4 = data. withUnsafeMutableBytes {
209+ Tensor < Float > ( bytesNoCopy: $0. baseAddress!, shape: [ 4 , 1 ] )
210210 }
211211 XCTAssertNotEqual ( tensor1, tensor4)
212212 XCTAssertEqual ( tensor1, tensor1)
213213 XCTAssertNotEqual ( tensor4, tensor2)
214- let tensor5 : Tensor < Float > = data. withUnsafeMutableBytes {
215- Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] , shapeDynamism: . static)
214+ let tensor5 = data. withUnsafeMutableBytes {
215+ Tensor < Float > ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 2 ] , shapeDynamism: . static)
216216 }
217217 XCTAssertEqual ( tensor1, tensor5)
218218 }
219219
220+ func testInitScalarsNoCopyDefaultShape( ) throws {
221+ var data : [ Float ] = [ 1.0 , 2.0 , 3.0 , 4.0 ]
222+ let tensor = Tensor ( & data)
223+
224+ XCTAssertEqual ( tensor. dataType, . float)
225+ XCTAssertEqual ( tensor. shape, [ 4 ] )
226+ XCTAssertEqual ( tensor. strides, [ 1 ] )
227+ XCTAssertEqual ( tensor. dimensionOrder, [ 0 ] )
228+ XCTAssertEqual ( tensor. shapeDynamism, . dynamicBound)
229+ XCTAssertEqual ( tensor. count, 4 )
230+ data [ 2 ] = 42.0
231+ XCTAssertEqual ( try tensor. scalars ( ) , data)
232+ }
233+
234+ func testInitScalarsNoCopyWithExplicitParams( ) throws {
235+ var data : [ Int ] = [ 10 , 20 , 30 , 40 ]
236+ let tensor = Tensor (
237+ & data,
238+ shape: [ 2 , 2 ] ,
239+ strides: [ 2 , 1 ] ,
240+ dimensionOrder: [ 1 , 0 ] ,
241+ shapeDynamism: . static
242+ )
243+ XCTAssertEqual ( tensor. dataType, . long)
244+ XCTAssertEqual ( tensor. shape, [ 2 , 2 ] )
245+ XCTAssertEqual ( tensor. strides, [ 2 , 1 ] )
246+ XCTAssertEqual ( tensor. dimensionOrder, [ 1 , 0 ] )
247+ XCTAssertEqual ( tensor. shapeDynamism, . static)
248+ XCTAssertEqual ( tensor. count, 4 )
249+ data = data. map { - $0 }
250+ XCTAssertEqual ( try tensor. scalars ( ) , data)
251+ }
252+
253+ func testInitScalarsBoolNoCopy( ) throws {
254+ var data : [ Bool ] = [ true , false , true ]
255+ let tensor = Tensor ( & data)
256+
257+ XCTAssertEqual ( tensor. dataType, . bool)
258+ XCTAssertEqual ( tensor. shape, [ 3 ] )
259+ XCTAssertEqual ( tensor. strides, [ 1 ] )
260+ XCTAssertEqual ( tensor. dimensionOrder, [ 0 ] )
261+ XCTAssertEqual ( tensor. count, 3 )
262+ data [ 1 ] . toggle ( )
263+ XCTAssertEqual ( try tensor. scalars ( ) , data)
264+ }
265+
220266 func testInitScalarsUInt8( ) {
221267 let data : [ UInt8 ] = [ 1 , 2 , 3 , 4 , 5 , 6 ]
222268 let tensor = Tensor ( data)
0 commit comments