Skip to content

Commit 7c20848

Browse files
authored
fix: materialize LinRange properly (#1389)
1 parent 0c8946c commit 7c20848

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.131"
4+
version = "0.2.132"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/TracedUtils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ function ReactantCore.materialize_traced_array(x::AbstractRange)
2626
return Reactant.aos_to_soa(collect(x))
2727
end
2828

29+
function ReactantCore.materialize_traced_array(r::LinRange)
30+
T = Reactant.unwrapped_eltype(r)
31+
idxs = Ops.iota(T, [length(r)]; iota_dimension=1)
32+
t = idxs ./ r.lendiv
33+
return T.((1 .- t) .* r.start .+ t .* r.stop)
34+
end
35+
2936
function ReactantCore.materialize_traced_array(x::Base.OneTo)
3037
return Ops.iota(Reactant.unwrapped_eltype(x), [length(x)]; iota_dimension=1)
3138
end

test/basic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,3 +1347,16 @@ end
13471347
@test @jit(circshift(x_ra, (-3, 2))) circshift(x, (-3, 2))
13481348
@test @jit(circshift(x_ra, (5, 2))) circshift(x, (5, 2))
13491349
end
1350+
1351+
linrange_mat(x1, x2) = Reactant.materialize_traced_array(LinRange(x1, x2, 10024))
1352+
1353+
@testset "LinRange" begin
1354+
x1 = 0.0f0
1355+
x2 = 1.0f0
1356+
x1_ra = Reactant.to_rarray(x1; track_numbers=Number)
1357+
x2_ra = Reactant.to_rarray(x2; track_numbers=Number)
1358+
1359+
@test @jit(linrange_mat(x1_ra, x2_ra)) collect(LinRange(x1, x2, 10024))
1360+
hlo = repr(@code_hlo(linrange_mat(x1_ra, x2_ra)))
1361+
@test contains(hlo, "stablehlo.iota")
1362+
end

0 commit comments

Comments
 (0)