diff --git a/src/Compiler.jl b/src/Compiler.jl index 56c5b413dd..025a171b30 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -691,7 +691,7 @@ const SUM_TO_CONV = Ref(false) const AGGRESSIVE_SUM_TO_CONV = Ref(false) const AGGRESSIVE_PROPAGATION = Ref(false) const DUS_SLICE_SIMPLIFY = Ref(true) -const CONCATS_TO_DUS = Ref(false) +const CONCATS_TO_DUS = Ref(true) # Optimization passes via transform dialect function optimization_passes( @@ -768,7 +768,6 @@ function optimization_passes( "iota_simplify<16>($max_constant_threshold)", "broadcast_in_dim_simplify<16>($max_constant_threshold)", "convert_concat<1>", - "dynamic_update_to_concat<1>", "slice_of_dynamic_update<1>", "slice_elementwise<1>", "dot_reshape_dot<1>", @@ -1177,8 +1176,16 @@ function optimization_passes( ) end + if !CONCATS_TO_DUS[] + push!(transform_passes_list, "dynamic_update_to_concat<1>") + end + lower_transform_passes = copy(transform_passes_list) + if CONCATS_TO_DUS[] + push!(transform_passes_list, "dynamic_update_to_concat<1>") + end + if recognize_comms append!( transform_passes_list, @@ -1198,6 +1205,14 @@ function optimization_passes( ], ",", ) + if CONCATS_TO_DUS[] + transform_passes = + transform_passes * + ",enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform" + if lower_comms + push!(lower_transform_passes, "concat_to_onedim_dus") + end + end func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") if lower_comms func_passes = @@ -1206,12 +1221,6 @@ function optimization_passes( join(lower_transform_passes, ';') * "},transform-interpreter,enzyme-hlo-remove-transform" end - if CONCATS_TO_DUS[] - push!( - transform_passes_list, - "enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform", - ) - end passes = String[] if compile_options.inline push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}") diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 1eff7b4b9d..7e0c1bddff 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -406,7 +406,9 @@ end Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims) -@inline function Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims; kwargs...) where {T} +@inline function Base.similar( + AT::Type{<:ConcretePJRTArray{T}}, dims::Dims; kwargs... +) where {T} return Base.similar(AT, T, dims; kwargs...) end @@ -416,7 +418,7 @@ function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) ) end Base.similar(a::ConcreteIFRTArray, dims::Dims) = similar(a, eltype(a), dims) -function Base.similar(::Type{ConcreteIFRTArray{T}}, dims) where {T} +function Base.similar(::Type{ConcreteIFRTArray{T}}, dims::Dims) where {T} return ConcreteIFRTArray(similar(Array{T}, dims)) end diff --git a/test/indexing.jl b/test/indexing.jl index 9487ecbe4a..9a4aad4a3c 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -47,6 +47,18 @@ end @test y ≈ Array(y_ra) end +function write_row_simple!(xs_mut, i, v) + xs_mut[:, i] = v + return nothing +end +@testset "setindex: DUS" begin + x_ra = similar(ConcreteRArray{Float64}, (76, 100)) + y_ra = similar(ConcreteRArray{Float64}, 76) + hlo = @code_hlo write_row_simple!(x_ra, 1, y_ra) + @test contains(repr(hlo), "dynamic_update_slice") + @test !contains(repr(hlo), "concatenate") +end + function maskset!(y, x) y[:] = x return nothing