@@ -45,77 +45,83 @@ using .MPS: MPSMatrix
4545 rowBytes = sizeof (T) * cols
4646 mats = 4
4747
48- desc = MPSMatrixDescriptor (rows, cols, rowBytes, T)
49- devmat = MPSMatrix (dev, desc)
50- @test devmat isa MPSMatrix
51- @test devmat. device == dev
52- @test devmat. rows == rows
53- @test devmat. columns == cols
54- @test devmat. rowBytes == rowBytes
55- @test devmat. matrices == 1
56- @test devmat. dataType == DT
57- @test devmat. matrixBytes == rowBytes * rows
58- @test devmat. offset == 0
59-
60- mat = MtlMatrix {T} (undef, rows, cols)
61- acols, arows = size (mat)
62- arowBytes = sizeof (T) * acols
63- abufmat = MPSMatrix (mat)
64- @test abufmat isa MPSMatrix
65- @test abufmat. device == dev
66- @test abufmat. rows == arows
67- @test abufmat. columns == acols
68- @test abufmat. rowBytes == arowBytes
69- @test abufmat. matrices == 1
70- @test abufmat. dataType == DT
71- @test abufmat. matrixBytes == arowBytes * arows
72- @test abufmat. offset == 0
73- @test abufmat. data == mat. data[]
74-
75- vmat = @view mat[:, 2 : 3 ]
76- vcols, vrows = size (vmat)
77- vrowBytes = sizeof (T) * vcols
78- vbufmat = MPSMatrix (vmat)
79- @test vbufmat isa MPSMatrix
80- @test vbufmat. device == dev
81- @test vbufmat. rows == vrows
82- @test vbufmat. columns == vcols
83- @test vbufmat. rowBytes == vrowBytes
84- @test vbufmat. matrices == 1
85- @test vbufmat. dataType == DT
86- @test vbufmat. matrixBytes == vrowBytes * vrows
87- @test vbufmat. offset == vmat. offset * sizeof (T)
88- @test vbufmat. data == vmat. data[]
89-
90- arr = MtlArray {T,3} (undef, rows, cols, mats)
91- mcols, mrows, mmats = size (arr)
92- mrowBytes = sizeof (T) * mcols
93- mpsmat = MPSMatrix (mat)
94- @test mpsmat isa MPSMatrix
95- @test mpsmat. device == dev
96- @test mpsmat. rows == mrows
97- @test mpsmat. columns == mcols
98- @test mpsmat. rowBytes == mrowBytes
99- @test mpsmat. matrices == 1
100- @test mpsmat. dataType == DT
101- @test mpsmat. matrixBytes == mrowBytes * mrows
102- @test mpsmat. offset == 0
103- @test mpsmat. data == mat. data[]
104-
105- vec = MtlVector {T} (undef, rows)
106- veccols, vecrows = length (vec), 1
107- vecrowBytes = sizeof (T)* veccols
108- vmpsmat = MPSMatrix (vec)
109- @test vmpsmat isa MPSMatrix
110- @test vmpsmat. device == dev
111- @test vmpsmat. rows == vecrows
112- @test vmpsmat. columns == veccols
113- @test vmpsmat. rowBytes == vecrowBytes
114- @test vmpsmat. matrices == 1
115- @test vmpsmat. dataType == DT
116- @test vmpsmat. matrixBytes == vecrowBytes* vecrows
117- @test vmpsmat. offset == 0
118- @test vmpsmat. data == vec. data[]
48+ let desc = MPSMatrixDescriptor (rows, cols, rowBytes, T)
49+ devmat = MPSMatrix (dev, desc)
50+ @test devmat isa MPSMatrix
51+ @test devmat. device == dev
52+ @test devmat. rows == rows
53+ @test devmat. columns == cols
54+ @test devmat. rowBytes == rowBytes
55+ @test devmat. matrices == 1
56+ @test devmat. dataType == DT
57+ @test devmat. matrixBytes == rowBytes * rows
58+ @test devmat. offset == 0
59+ @test size (devmat) == (rows, cols)
60+ end
61+
62+ let mat = MtlMatrix {T} (undef, rows, cols)
63+ acols, arows = size (mat)
64+ arowBytes = sizeof (T) * acols
65+ abufmat = MPSMatrix (mat)
66+ @test abufmat isa MPSMatrix
67+ @test abufmat. device == dev
68+ @test abufmat. rows == arows
69+ @test abufmat. columns == acols
70+ @test abufmat. rowBytes == arowBytes
71+ @test abufmat. matrices == 1
72+ @test abufmat. dataType == DT
73+ @test abufmat. matrixBytes == arowBytes * arows
74+ @test abufmat. offset == 0
75+ @test abufmat. data == mat. data[]
76+
77+ vmat = @view mat[:, 2 : 3 ]
78+ vcols, vrows = size (vmat)
79+ vrowBytes = sizeof (T) * vcols
80+ vbufmat = MPSMatrix (vmat)
81+ @test vbufmat isa MPSMatrix
82+ @test vbufmat. device == dev
83+ @test vbufmat. rows == vrows
84+ @test vbufmat. columns == vcols
85+ @test vbufmat. rowBytes == vrowBytes
86+ @test vbufmat. matrices == 1
87+ @test vbufmat. dataType == DT
88+ @test vbufmat. matrixBytes == vrowBytes * vrows
89+ @test vbufmat. offset == vmat. offset * sizeof (T)
90+ @test vbufmat. data == vmat. data[]
91+ end
92+
93+ let arr = MtlArray {T, 3} (undef, rows, cols, mats)
94+ mcols, mrows, mmats = size (arr)
95+ mrowBytes = sizeof (T) * mcols
96+ mpsmat = MPSMatrix (arr)
97+ @test mpsmat isa MPSMatrix
98+ @test mpsmat. device == dev
99+ @test mpsmat. rows == mrows
100+ @test mpsmat. columns == mcols
101+ @test mpsmat. rowBytes == mrowBytes
102+ @test mpsmat. matrices == mmats
103+ @test mpsmat. dataType == DT
104+ @test mpsmat. matrixBytes == mrowBytes * mrows
105+ @test mpsmat. offset == 0
106+ @test mpsmat. data == arr. data[]
107+ @test size (mpsmat) == (mmats, mrows, mcols)
108+ end
109+
110+ let vec = MtlVector {T} (undef, rows)
111+ veccols, vecrows = length (vec), 1
112+ vecrowBytes = sizeof (T) * veccols
113+ vmpsmat = MPSMatrix (vec)
114+ @test vmpsmat isa MPSMatrix
115+ @test vmpsmat. device == dev
116+ @test vmpsmat. rows == vecrows
117+ @test vmpsmat. columns == veccols
118+ @test vmpsmat. rowBytes == vecrowBytes
119+ @test vmpsmat. matrices == 1
120+ @test vmpsmat. dataType == DT
121+ @test vmpsmat. matrixBytes == vecrowBytes * vecrows
122+ @test vmpsmat. offset == 0
123+ @test vmpsmat. data == vec. data[]
124+ end
119125end
120126
121127
0 commit comments