@@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
1818
1919export MPSMatrixDescriptor
2020
21- @objcwrapper MPSMatrixDescriptor <: NSObject
21+ @objcwrapper immutable= false MPSMatrixDescriptor <: NSObject
22+
23+ @objcproperties MPSMatrixDescriptor begin
24+ @autoproperty rows:: NSUInteger setter= setRows
25+ @autoproperty columns:: NSUInteger setter= setColumns
26+ @autoproperty matrices:: NSUInteger
27+ @autoproperty dataType:: MPSDataType setter= setDataType
28+ @autoproperty rowBytes:: NSUInteger setter= setRowBytes
29+ @autoproperty matrixBytes:: NSUInteger
30+ end
31+
2232
2333# Mapping from Julia types to the Performance Shader bitfields
2434const jl_typ_to_mps = Dict {DataType,MPSDataType} (
@@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
4959 return obj
5060end
5161
62+ function MPSMatrixDescriptor (rows, columns, matrices, rowBytes, matrixBytes, dataType)
63+ desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows: rows:: NSUInteger
64+ columns: columns:: NSUInteger
65+ matrices: matrices:: NSUInteger
66+ rowBytes: rowBytes:: NSUInteger
67+ matrixBytes: matrixBytes:: NSUInteger
68+ dataType: jl_typ_to_mps[dataType]:: MPSDataType ]:: id{MPSMatrixDescriptor}
69+ obj = MPSMatrixDescriptor (desc)
70+ # XXX : who releases this object?
71+ return obj
72+ end
5273
5374#
5475# matrix object
@@ -58,6 +79,19 @@ export MPSMatrix
5879
5980@objcwrapper immutable= false MPSMatrix <: NSObject
6081
82+ @objcproperties MPSMatrix begin
83+ @autoproperty device:: id{MTLDevice}
84+ @autoproperty rows:: NSUInteger
85+ @autoproperty columns:: NSUInteger
86+ @autoproperty matrices:: NSUInteger
87+ @autoproperty dataType:: MPSDataType
88+ @autoproperty rowBytes:: NSUInteger
89+ @autoproperty matrixBytes:: NSUInteger
90+ @autoproperty offset:: NSUInteger
91+ @autoproperty data:: id{MTLBuffer}
92+ end
93+
94+
6195"""
6296 MPSMatrix(arr::MtlMatrix)
6397
@@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
71105 desc = MPSMatrixDescriptor (n_rows, n_cols, sizeof (T)* n_cols, T)
72106 mat = @objc [MPSMatrix alloc]:: id{MPSMatrix}
73107 obj = MPSMatrix (mat)
108+ offset = arr. offset * sizeof (T)
74109 finalizer (release, obj)
75110 @objc [obj:: id{MPSMatrix} initWithBuffer: arr:: id{MTLBuffer}
111+ offset: offset:: NSUInteger
76112 descriptor: desc:: id{MPSMatrixDescriptor} ]:: id{MPSMatrix}
77113 return obj
78114end
79115
80116
117+ """
118+ MPSMatrix(arr::MtlArray{T,3})
119+
120+ Metal batched matrix representation used in Performance Shaders.
121+
122+ Note that this results in a transposed view of the input,
123+ as Metal stores matrices row-major instead of column-major.
124+ """
125+ function MPSMatrix (arr:: MtlArray{T,3} ) where T
126+ n_cols, n_rows, n_matrices = size (arr)
127+ row_bytes = sizeof (T)* n_cols
128+ desc = MPSMatrixDescriptor (n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
129+ mat = @objc [MPSMatrix alloc]:: id{MPSMatrix}
130+ obj = MPSMatrix (mat)
131+ offset = arr. offset * sizeof (T)
132+ finalizer (release, obj)
133+ @objc [obj:: id{MPSMatrix} initWithBuffer: arr:: id{MTLBuffer}
134+ offset: offset:: NSUInteger
135+ descriptor: desc:: id{MPSMatrixDescriptor} ]:: id{MPSMatrix}
136+ return obj
137+ end
138+
81139#
82140# matrix multiplication
83141#
0 commit comments