Skip to content

Commit 24f9acb

Browse files
authored
feat: generalize dot (#1393)
1 parent f716f6b commit 24f9acb

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ function LinearAlgebra.dot(x::AnyTracedRVector, y::AnyTracedRVector)
491491
return TracedRNumber{unwrapped_eltype(res)}((), res.mlir_data)
492492
end
493493

494+
LinearAlgebra.dot(x::AnyTracedRArray, y::AnyTracedRArray) = dot(vec(x), vec(y))
495+
496+
function LinearAlgebra.dot(x::AnyTracedRVector, A::AnyTracedRMatrix, y::AnyTracedRVector)
497+
return dot(x, A * y)
498+
end
499+
494500
# ldiv & rdiv interfaces
495501
tfun_to_char(::typeof(identity)) = 'N'
496502
tfun_to_char(::typeof(transpose)) = 'T'

test/integration/linear_algebra.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,19 +265,44 @@ end
265265
end
266266

267267
@testset "Dot" begin
268-
x = collect(Float32, 1:10)
269-
y = collect(Float32, 10:-1:1)
270-
x_ra = Reactant.to_rarray(x)
271-
y_ra = Reactant.to_rarray(y)
268+
@testset "2-arg real" begin
269+
x = collect(Float32, 1:10)
270+
y = collect(Float32, 10:-1:1)
271+
x_ra = Reactant.to_rarray(x)
272+
y_ra = Reactant.to_rarray(y)
272273

273-
@test @jit(dot(x_ra, y_ra)) dot(x, y)
274+
@test @jit(dot(x_ra, y_ra)) dot(x, y)
274275

275-
x = rand(Complex{Float32}, 4)
276-
y = rand(Complex{Float32}, 4)
277-
x_ra = Reactant.to_rarray(x)
278-
y_ra = Reactant.to_rarray(y)
276+
x = reshape(collect(Float32, 1:10), 2, 5)
277+
x_ra = Reactant.to_rarray(x)
278+
279+
@test @jit(dot(x_ra, x_ra)) dot(x, x)
280+
end
279281

280-
@test @jit(dot(x_ra, y_ra)) dot(x, y)
282+
@testset "2-arg complex" begin
283+
x = rand(Complex{Float32}, 4)
284+
y = rand(Complex{Float32}, 4)
285+
x_ra = Reactant.to_rarray(x)
286+
y_ra = Reactant.to_rarray(y)
287+
288+
@test @jit(dot(x_ra, y_ra)) dot(x, y)
289+
290+
x = rand(Complex{Float32}, 2, 2)
291+
x_ra = Reactant.to_rarray(x)
292+
293+
@test @jit(dot(x_ra, x_ra)) dot(x, x)
294+
end
295+
296+
@testset "3-arg" begin
297+
x = rand(Float32, 2, 2)
298+
y = rand(Float32, 4, 5)
299+
z = rand(Float32, 5)
300+
x_ra = Reactant.to_rarray(x)
301+
y_ra = Reactant.to_rarray(y)
302+
z_ra = Reactant.to_rarray(z)
303+
304+
@test @jit(dot(x_ra, y_ra, z_ra)) dot(x, y, z)
305+
end
281306
end
282307

283308
@testset "Triangular ldiv and rdiv" begin

0 commit comments

Comments
 (0)