Skip to content

Commit a92246c

Browse files
avik-palwsmoses
andauthored
feat: implementing view/getindex/setindex (#104)
* feat: implementing view/getindex/setindex * fix: try passing in T,N directly * fix * fix * fix * Fix * Fix * Fix * Fix * Fix * fix --------- Co-authored-by: William S. Moses <[email protected]>
1 parent e86553c commit a92246c

File tree

3 files changed

+138
-43
lines changed

3 files changed

+138
-43
lines changed

src/TracedRArray.jl

Lines changed: 96 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1616
end
1717
end
1818

19+
function Base.getindex(a::TracedRArray{T,0}) where T
20+
return a
21+
end
22+
1923
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Integer,N}) where {T,N}
2024
@warn(
2125
"""Performing scalar indexing on task $(current_task()).
@@ -43,6 +47,46 @@ and require expensive copies and synchronization each time and therefore should
4347
return TracedRArray{T,0}((), res2, ())
4448
end
4549

50+
function Base.getindex(
51+
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
52+
) where {T,N}
53+
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
54+
res = MLIR.IR.result(
55+
MLIR.Dialects.stablehlo.slice(
56+
a.mlir_data;
57+
start_indices=MLIR.IR.DenseArrayAttribute([
58+
Int64(first(i) - 1) for i in indices
59+
]),
60+
limit_indices=MLIR.IR.DenseArrayAttribute([Int64(last(i)) for i in indices]),
61+
strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in indices]),
62+
),
63+
1,
64+
)
65+
return TracedRArray{T,N}((), res, Tuple(length.(indices)))
66+
end
67+
68+
function Base.view(
69+
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
70+
) where {T,N}
71+
# TODO: Implement before merging the PR
72+
return error("view is not supported yet")
73+
end
74+
75+
function Base.setindex!(
76+
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
77+
) where {T,N}
78+
indices = [(promote_to(TracedRArray{Int, 0}, i isa Colon ? 1 : first(i))-1).mlir_data for i in indices]
79+
v = promote_to(TracedRArray{T,N}, v)
80+
res = MLIR.IR.result(
81+
MLIR.Dialects.stablehlo.dynamic_update_slice(
82+
a.mlir_data, v.mlir_data, indices
83+
),
84+
1,
85+
)
86+
a.mlir_data = res
87+
return v
88+
end
89+
4690
Base.size(x::TracedRArray) = x.shape
4791

4892
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray((), A.mlir_data, size(A))
@@ -118,6 +162,9 @@ end
118162

119163
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
120164
if isa(rhs, TracedRArray)
165+
if typeof(rhs) == TracedRArray{T, N}
166+
return rhs
167+
end
121168
return TracedRArray{T,N}(
122169
(),
123170
MLIR.IR.result(
@@ -136,10 +183,11 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
136183
)
137184
return ta
138185
end
139-
attr = MLIR.IR.DenseElementsAttribute(mlir_type(TracedRArray{T,N}, size(rhs)), rhs)
140-
return TracedRArray{T,N}(
186+
T0 = eltype(rhs)
187+
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
188+
return promote_to(TracedRArray{T, N}, TracedRArray{T0,length(size(rhs))}(
141189
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
142-
)
190+
))
143191
end
144192

145193
function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
@@ -410,13 +458,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
410458
return traced2_result
411459
end
412460

413-
for (jlop, hloop, hlocomp) in (
414-
(:(Base.:(==)), :compare, "EQ"),
415-
(:(Base.:(!=)), :compare, "NE"),
416-
(:(Base.:(>=)), :compare, "GE"),
417-
(:(Base.:(>)), :compare, "GT"),
418-
(:(Base.:(<=)), :compare, "LE"),
419-
(:(Base.:(<)), :compare, "LT"),
461+
for (jlop, hloop, hlocomp, merge) in (
462+
(:(Base.:(==)), :compare, "EQ", :all),
463+
(:(Base.:(!=)), :compare, "NE", :any),
464+
(:(Base.:(>=)), :compare, "GE", nothing),
465+
(:(Base.:(>)), :compare, "GT", nothing),
466+
(:(Base.:(<=)), :compare, "LE", nothing),
467+
(:(Base.:(<)), :compare, "LT", nothing),
420468
)
421469
@eval begin
422470
function elem_apply(
@@ -441,43 +489,50 @@ for (jlop, hloop, hlocomp) in (
441489
end
442490

443491
function elem_apply(
444-
::typeof($jlop), @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
492+
fn::typeof($jlop), @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
445493
) where {T,N}
446-
rhs = promote_to(lhs, rhs)
447-
return TracedRArray{T,N}(
448-
(),
449-
MLIR.IR.result(
450-
MLIR.Dialects.stablehlo.$hloop(
451-
lhs.mlir_data,
452-
rhs.mlir_data;
453-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
454-
MLIR.IR.context(), $hlocomp
455-
),
456-
),
457-
1,
458-
),
459-
size(lhs),
460-
)
494+
elem_apply(fn, lhs, promote_to(lhs, rhs))
461495
end
462496

463497
function elem_apply(
464498
::typeof($jlop), @nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
465499
) where {T,N}
466-
lhs = promote_to(rhs, lhs)
467-
return TracedRArray{T,N}(
468-
(),
469-
MLIR.IR.result(
470-
MLIR.Dialects.stablehlo.$hloop(
471-
lhs.mlir_data,
472-
rhs.mlir_data;
473-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
474-
MLIR.IR.context(), $hlocomp
475-
),
476-
),
477-
1,
478-
),
479-
size(lhs),
480-
)
500+
elem_apply(fn, promote_to(rhs, lhs), rhs)
501+
end
502+
503+
function $jlop(@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
504+
) where {T, N}
505+
$jlop(lhs, promote_to(lhs, rhs))
506+
end
507+
508+
function $jlop(@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
509+
) where {T, N}
510+
$jlop(promote_to(rhs, lhs), rhs)
511+
end
512+
end
513+
514+
if merge != nothing
515+
@eval begin
516+
function $jlop(
517+
@nospecialize(lhs::TracedRArray{T,N}),
518+
@nospecialize(rhs::TracedRArray{T,N})
519+
) where {T,N}
520+
elems = elem_apply($jlop, lhs, rhs)
521+
if N == 0
522+
elems
523+
else
524+
$merge(elems)
525+
end
526+
end
527+
end
528+
else
529+
@eval begin
530+
function $jlop(
531+
@nospecialize(lhs::TracedRArray{T,0}),
532+
@nospecialize(rhs::TracedRArray{T,0})
533+
) where {T}
534+
elem_apply($jlop, lhs, rhs)
535+
end
481536
end
482537
end
483538
end

src/Tracing.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,15 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function}
3838

3939
# in closures, enclosured variables need to be traced
4040
N = fieldcount(T)
41+
changed = false
4142
traced_fieldtypes = ntuple(Val(N)) do i
42-
return traced_type(fieldtype(T, i), seen, mode)
43+
next = traced_type(fieldtype(T, i), seen, mode)
44+
changed |= next != fieldtype(T, i)
45+
next
46+
end
47+
48+
if !changed
49+
return T
4350
end
4451

4552
# closure are struct types with the types of enclosured vars as type parameters
@@ -426,7 +433,7 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
426433
end
427434
prev2 = prev.contents
428435
tr = make_tracer(seen, prev2, append_path(path, :contents), mode; kwargs...)
429-
if tr == prev2
436+
if tr === prev2
430437
seen[prev] = prev
431438
return prev
432439
end

test/basic.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,36 @@ end
212212
@test cat2(x) cat2_compiled(x_concrete)
213213
@test cat3(x) cat3_compiled(x_concrete)
214214
end
215+
216+
function update_on_copy(x)
217+
y = x[1:2, 2:4, :]
218+
y[1:1, 1:1, :] = ones(1, 1, 3)
219+
return y
220+
end
221+
222+
@testset "view / setindex" begin
223+
x = rand(2, 4, 3)
224+
y = copy(x)
225+
x_concrete = Reactant.to_rarray(x)
226+
y_concrete = Reactant.to_rarray(y)
227+
228+
update_on_copy_compiled = Reactant.compile(update_on_copy, (x_concrete,))
229+
230+
y1 = update_on_copy(x)
231+
y2 = update_on_copy_compiled(x_concrete)
232+
@test x == y
233+
@test x_concrete == y_concrete
234+
@test y1 == y2
235+
236+
# function update_inplace(x)
237+
# y = view(x, 1:2, 1:2, :)
238+
# y[1, 1, :] .= 1
239+
# return y
240+
# end
241+
242+
# get_indices(x) = x[1:2, 1:2, :]
243+
# get_view(x) = view(x, 1:2, 1:2, :)
244+
245+
# get_indices_compiled = Reactant.compile(get_indices, (x_concrete,))
246+
# get_view_compiled = Reactant.compile(get_view, (x_concrete,))
247+
end

0 commit comments

Comments
 (0)