Skip to content

Commit a874ef7

Browse files
authored
feat: remove closure restriction on batching (#1443)
1 parent 148e1c2 commit a874ef7

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.142"
4+
version = "0.2.143"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Ops.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2876,7 +2876,6 @@ end
28762876
args_in_result=:none,
28772877
do_transpose=false,
28782878
)
2879-
@assert !mlir_fn_res.fnwrapped "Currently we don't support batching closures."
28802879

28812880
func = mlir_fn_res.f
28822881
@assert MLIR.IR.nregions(func) == 1

test/batching.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ using Reactant, Test
22

33
f1(x::AbstractMatrix) = sum(x; dims=1)
44

5+
f2(x::AbstractMatrix, y::Int) = x .+ y
6+
7+
function f3(x::AbstractArray{T,3}, y::Int) where {T}
8+
return Reactant.Ops.batch(Base.Fix2(f2, y), x, [1, 2])
9+
end
10+
511
@testset "mapslices" begin
612
A = collect(reshape(1:30, (2, 5, 3)))
713
A_ra = Reactant.to_rarray(A)
@@ -16,3 +22,10 @@ f1(x::AbstractMatrix) = sum(x; dims=1)
1622

1723
@test B B_ra
1824
end
25+
26+
@testset "closure" begin
27+
A = collect(reshape(1:30, (2, 5, 3)))
28+
A_ra = Reactant.to_rarray(A)
29+
30+
@test @jit(f3(A_ra, 1)) A .+ 1
31+
end

0 commit comments

Comments
 (0)