Skip to content

Commit ef63aa4

Browse files
authored
Merge branch 'main' into wsmoses-patch-4
2 parents aba9c6a + d91b736 commit ef63aa4

File tree

9 files changed

+66
-8
lines changed

9 files changed

+66
-8
lines changed

deps/build_local.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ else
172172
push!(build_cmd_list, "--define=using_clang=true")
173173
push!(build_cmd_list, "--copt=-Wno-unused-command-line-argument")
174174
end
175+
push!(build_cmd_list, "--copt=-Wno-private-header")
175176
push!(build_cmd_list, "--color=$(parsed_args["color"])")
176177
push!(build_cmd_list, ":libReactantExtra.so")
177178

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,7 @@ end
14891489
end
14901490
y = Reactant.ConcreteRArray([2.0]; client)
14911491
Reactant.Compiler.compile_mlir(square!, (y,); optimize=false)
1492+
finalize(y)
14921493
end
14931494
end
14941495

src/Ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,7 @@ function broadcast_in_dim(
10641064
location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
10651065
) where {T,N}
10661066
@assert length(dims) == N
1067+
@assert length(result_size) N
10671068

10681069
res = MLIR.IR.result(
10691070
stablehlo.broadcast_in_dim(

src/Precompile.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ if Reactant_jll.is_available()
6969
@static if precompilation_supported()
7070
x = ConcreteRNumber(2.0; client)
7171
Reactant.compile(sin, (x,); client, optimize=:all)
72+
finalize(x)
7273

7374
y = ConcreteRArray([2.0]; client)
7475
Reactant.compile(Base.sum, (y,); client, optimize=:all)
76+
finalize(y)
7577
end
7678
end
7779

src/TracedRArray.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,14 @@ end
543543
function overloaded_mapreduce(
544544
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
545545
)
546-
return overloaded_mapreduce(identity, op, unwrapped_broadcast(f, A); dims, init)
546+
res = unwrapped_broadcast(f, A)
547+
# This means we are unable to use the optimized dispatches. For now we will
548+
# unroll the mapreduce.
549+
if typeof(res) == typeof(A)
550+
@assert dims == Colon() "dims not supported for mapreduce currently."
551+
return foldl(op, res; init)
552+
end
553+
return overloaded_mapreduce(identity, op, res; dims=:, init)
547554
end
548555

549556
function overloaded_mapreduce(
@@ -1436,16 +1443,27 @@ struct BroadcastIterator{F}
14361443
f::F
14371444
end
14381445

1439-
(fn::BroadcastIterator)(args...) = fn.f((args...,))
1446+
(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
14401447

14411448
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
14421449
min_length = Base.inferencebarrier(minimum)(length, x.is)
14431450
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
1444-
return (BroadcastIterator(f)).(itrs...)
1451+
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
1452+
return (BroadcastIterator(f)).(itrs...)
1453+
else
1454+
fn = BroadcastIterator(f)
1455+
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
1456+
end
14451457
end
14461458

14471459
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
1448-
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
1460+
if x.itr isa AnyTracedRArray
1461+
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
1462+
else
1463+
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
1464+
end
14491465
end
14501466

1467+
unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
1468+
14511469
end

src/xla/Buffer.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
abstract type AbstractBuffer end
22

3+
function free_buffer end
34
function synced_buffer end
45
function buffer_on_cpu end
56
function to_host end

src/xla/IFRT/Array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mutable struct Array <: XLA.AbstractBuffer
33

44
function Array(buffer::Ptr{Cvoid}, owned::Bool=true)
55
!owned && return new(buffer)
6-
return finalizer(free_ifrt_array, new(buffer))
6+
return finalizer(XLA.free_buffer, new(buffer))
77
end
88
end
99

@@ -158,7 +158,7 @@ function Array(
158158
return Array(client, array, ifrt_sharding)
159159
end
160160

161-
@inline function free_ifrt_array(buffer::Array)
161+
@inline function XLA.free_buffer(buffer::Array)
162162
if buffer.buffer != C_NULL
163163
@ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid
164164
end

src/xla/PJRT/Buffer.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mutable struct Buffer <: XLA.AbstractBuffer
22
buffer::Ptr{Cvoid}
33

44
function Buffer(buffer::Ptr{Cvoid})
5-
return finalizer(free_buffer, new(buffer))
5+
return finalizer(XLA.free_buffer, new(buffer))
66
end
77
end
88

@@ -114,7 +114,7 @@ function Base.similar(a::Buffer, S::Type, dims::Dims)
114114
return Base.similar(Buffer, S, dims; client=XLA.client(a), device=XLA.device(a))
115115
end
116116

117-
@inline function free_buffer(buffer::Buffer)
117+
@inline function XLA.free_buffer(buffer::Buffer)
118118
sbuffer = buffer.buffer
119119
if sbuffer != C_NULL
120120
@ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid

test/basic.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,20 @@ end
14441444
zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
14451445
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
14461446

1447+
function nested_mapreduce_zip(x, y)
1448+
return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
1449+
return sum(abs2, x) + sum(abs2, y)
1450+
end
1451+
end
1452+
1453+
function nested_mapreduce_hcat(x, y)
1454+
return mapreduce(
1455+
hcat, zip(eachcol(x), eachcol(y)); init=similar(x, size(x, 1), 0)
1456+
) do (x, y)
1457+
return x .+ y
1458+
end
1459+
end
1460+
14471461
@testset "Base.Iterators" begin
14481462
@testset "zip" begin
14491463
N = 10
@@ -1460,4 +1474,24 @@ enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
14601474

14611475
@test @jit(enumerate_iterator(x_ra)) enumerate_iterator(x)
14621476
end
1477+
1478+
@testset "nested mapreduce" begin
1479+
x = rand(Float32, 4, 3)
1480+
y = rand(Float32, 4, 3)
1481+
1482+
x_ra = Reactant.to_rarray(x)
1483+
y_ra = Reactant.to_rarray(y)
1484+
1485+
@test @jit(nested_mapreduce_zip(x_ra, y_ra)) nested_mapreduce_zip(x, y)
1486+
end
1487+
1488+
@testset "nested mapreduce hcat" begin
1489+
x = rand(Float32, 4, 3)
1490+
y = rand(Float32, 4, 3)
1491+
1492+
x_ra = Reactant.to_rarray(x)
1493+
y_ra = Reactant.to_rarray(y)
1494+
1495+
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) nested_mapreduce_hcat(x, y)
1496+
end
14631497
end

0 commit comments

Comments
 (0)