Skip to content

Commit 6f2269e

Browse files
Fix MPS.synchronize_state (#434)
1 parent 71b784e commit 6f2269e

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

lib/mps/matrixrandom.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ for R in [:MPSMatrixRandomMTGP32, :MPSMatrixRandomPhilox]
107107
end
108108

109109
synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) =
110-
@objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing
110+
@objc [kern::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing
111111

112112

113113
@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2};

test/mps/matrix.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,10 @@ using .MPS: MPSMatrixFindTopK
170170
@test topk.sourceColumns == cols
171171
@test topk.sourceRows == rows
172172
end
173+
174+
# Ensure that the function does not error
175+
@testset "MPSMatrixRandom sync state" begin
176+
cmdbuf = MTL.MTLCommandBuffer(global_queue(device()))
177+
rng = MPS.MPSMatrixRandomMTGP32(device())
178+
@test isnothing(MPS.synchronize_state(rng, cmdbuf))
179+
end

0 commit comments

Comments
 (0)