Skip to content

Commit d37e9dd

Browse files
Couple typos and is_m4 function (#498)
[skip benchmarks]
1 parent 60a9e34 commit d37e9dd

File tree

4 files changed

+14
-13
lines changed

4 files changed

+14
-13
lines changed

lib/mps/matrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function encode!(cmdbuf::MTLCommandBuffer, matmul::MPSMatrixMultiplication, left
211211
end
212212

213213
"""
214-
matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
214+
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
215215
transpose_left=false, transpose_right=false)
216216
A `MPSMatrixMultiplication` kernel thay computes:
217217
`c = alpha * op(a) * beta * op(b) + beta * C`

lib/mps/matrixrandom.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export MPSMatrixRandomDistributionDescriptor
1616
@autoproperty distributionType::MPSMatrixRandomDistribution
1717
@autoproperty maximum::Float32 setter=setMaximum
1818
@autoproperty mean::Float32 setter=setMean
19-
@autoproperty minimum::Float32 setter=setMimimum
19+
@autoproperty minimum::Float32 setter=setMinimum
2020
@autoproperty standardDeviation::Float32 setter=setStandardDeviation
2121
end
2222

lib/mps/ndarray.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282
end
8383

8484
function MPSTemporaryNDArray(cmdbuf::MTLCommandBuffer, descriptor::MPSNDArrayDescriptor)
85-
@objc [MPSNDTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
85+
@objc [MPSTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
8686
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSTemporaryNDArray}
8787
return obj
8888
end
@@ -123,7 +123,7 @@ end
123123
return obj
124124
end
125125
else
126-
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
126+
function MPSNDArray(_::MTLBuffer, _::UInt, _::MPSNDArrayDescriptor)
127127
@assert false "Creating an MPSNDArray that shares data with user-provided MTLBuffer is only supported in macOS v15+"
128128
end
129129
end
@@ -135,20 +135,18 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
135135
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
136136
end
137137

138-
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode)
138+
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
139139
ndims = Int(ndarr.numberOfDimensions)
140140
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
141141
T = convert(DataType, ndarr.dataType)
142142
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
143143
dev = device(arr)
144144

145-
cmdBuf = MTLCommandBuffer(global_queue(dev))
146-
147-
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
148-
149-
commit!(cmdBuf)
150-
wait_completed(cmdBuf)
145+
cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
146+
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
147+
end
151148

149+
async || wait_completed(cmdBuf)
152150
return arr
153151
end
154152

lib/mtl/device.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ MTLDevice(i::Integer) = devices()[i]
9191
# family
9292
#
9393

94-
export supports_family, is_m3, is_m2, is_m1
94+
export supports_family, is_m4, is_m3, is_m2, is_m1
9595

9696
@cenum MTLGPUFamily::NSInteger begin
9797
MTLGPUFamilyMetal3 = 5001 # Metal 3 support
@@ -121,4 +121,7 @@ is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) &&
121121
!supports_family(dev, MTLGPUFamilyApple8)
122122
is_m2(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple8) &&
123123
!supports_family(dev, MTLGPUFamilyApple9)
124-
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9)
124+
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
125+
occursin("M3", String(dev.name))
126+
is_m4(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
127+
occursin("M4", String(dev.name))

0 commit comments

Comments
 (0)