From 408037d01a57add1f39d2b6e86d339003efce7cf Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 27 Jul 2025 20:16:14 -0600 Subject: [PATCH 1/5] Default to post conversion to dynamic update slice --- src/Compiler.jl | 23 +++++++++++++++-------- test/indexing.jl | 13 +++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index a120081b13..4a76bd526c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -691,7 +691,7 @@ const SUM_TO_CONV = Ref(false) const AGGRESSIVE_SUM_TO_CONV = Ref(false) const AGGRESSIVE_PROPAGATION = Ref(false) const DUS_SLICE_SIMPLIFY = Ref(true) -const CONCATS_TO_DUS = Ref(false) +const CONCATS_TO_DUS = Ref(true) # Optimization passes via transform dialect function optimization_passes( @@ -768,7 +768,6 @@ function optimization_passes( "iota_simplify<16>($max_constant_threshold)", "broadcast_in_dim_simplify<16>($max_constant_threshold)", "convert_concat<1>", - "dynamic_update_to_concat<1>", "slice_of_dynamic_update<1>", "slice_elementwise<1>", "dot_reshape_dot<1>", @@ -1163,8 +1162,16 @@ function optimization_passes( ) end + if !CONCATS_TO_DUS[] + push!(transform_passes_list, "dynamic_update_to_concat<1>") + end + lower_transform_passes = copy(transform_passes_list) + if CONCATS_TO_DUS[] + push!(transform_passes_list, "dynamic_update_to_concat<1>") + end + if recognize_comms append!( transform_passes_list, @@ -1184,6 +1191,12 @@ function optimization_passes( ], ",", ) + if CONCATS_TO_DUS[] + transform_passes = transform_passes * ",enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform" + if lower_comms + push!(lower_transform_passes, "concat_to_onedim_dus") + end + end func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") if lower_comms func_passes = @@ -1192,12 +1205,6 @@ function optimization_passes( join(lower_transform_passes, ';') * "},transform-interpreter,enzyme-hlo-remove-transform" end - if CONCATS_TO_DUS[] - push!( - transform_passes_list, - "enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform", - ) - end passes = String[] if compile_options.inline push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}") diff --git a/test/indexing.jl b/test/indexing.jl index 9487ecbe4a..ae3d9a32da 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -47,6 +47,19 @@ end @test y ≈ Array(y_ra) end +function write_row_simple!(xs_mut, i, v) + xs_mut[:, i] = v + nothing +end + +@testset "setindex: DUS" begin + x_ra = similar(ConcreteRArray{Float64}, 76, 100) + y_ra = similar(ConcreteRArray{Float64}, 76) + hlo = @code_hlo write_row_simple!(x_ra, 1, y_ra) + @test contains(repr(hlo), "dynamic_update_slice") + @test !contains(repr(hlo), "concatenate") +end + function maskset!(y, x) y[:] = x return nothing From ad064e5cd439ba5e19ea3cf2559057594647b156 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 27 Jul 2025 20:21:28 -0600 Subject: [PATCH 2/5] fix --- src/ConcreteRArray.jl | 2 +- test/indexing.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index d4c26b3843..1bf490921c 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -409,7 +409,7 @@ end Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims) -@inline function Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims; kwargs...) where {T} +@inline function Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims::Dims; kwargs...) where {T} return Base.similar(AT, T, dims; kwargs...) end diff --git a/test/indexing.jl b/test/indexing.jl index ae3d9a32da..57a081092b 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -53,8 +53,8 @@ function write_row_simple!(xs_mut, i, v) end @testset "setindex: DUS" begin - x_ra = similar(ConcreteRArray{Float64}, 76, 100) - y_ra = similar(ConcreteRArray{Float64}, 76) + x_ra = similar(ConcreteRArray{Float64}, (76, 100)) + y_ra = similar(ConcreteRArray{Float64}, (76)) hlo = @code_hlo write_row_simple!(x_ra, 1, y_ra) @test contains(repr(hlo), "dynamic_update_slice") @test !contains(repr(hlo), "concatenate") From a7da3d00c18a7287b8e419a6dd3344edbbe60bf6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 27 Jul 2025 20:26:30 -0600 Subject: [PATCH 3/5] Update test/indexing.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> --- test/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/indexing.jl b/test/indexing.jl index 57a081092b..e8a54fac99 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -54,7 +54,7 @@ end @testset "setindex: DUS" begin x_ra = similar(ConcreteRArray{Float64}, (76, 100)) - y_ra = similar(ConcreteRArray{Float64}, (76)) + y_ra = similar(ConcreteRArray{Float64}, 76) hlo = @code_hlo write_row_simple!(x_ra, 1, y_ra) @test contains(repr(hlo), "dynamic_update_slice") @test !contains(repr(hlo), "concatenate") From 8338e9e106577041f9b89c6c8a9940a6880b2498 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Jul 2025 12:36:18 -0600 Subject: [PATCH 4/5] Update ConcreteRArray.jl --- src/ConcreteRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 1bf490921c..06ae860bbd 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -419,7 +419,7 @@ function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) ) end Base.similar(a::ConcreteIFRTArray, dims::Dims) = similar(a, eltype(a), dims) -function Base.similar(::Type{ConcreteIFRTArray{T}}, dims) where {T} +function Base.similar(::Type{ConcreteIFRTArray{T}}, dims::Dims) where {T} return ConcreteIFRTArray(similar(Array{T}, dims)) end From b2715208340c09805adb553033b5b130fc376ed0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 6 Aug 2025 00:43:14 -0400 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 10 ++++++---- src/ConcreteRArray.jl | 4 +++- test/indexing.jl | 5 ++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 688593ca87..025a171b30 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1177,13 +1177,13 @@ function optimization_passes( end if !CONCATS_TO_DUS[] - push!(transform_passes_list, "dynamic_update_to_concat<1>") + push!(transform_passes_list, "dynamic_update_to_concat<1>") end lower_transform_passes = copy(transform_passes_list) if CONCATS_TO_DUS[] - push!(transform_passes_list, "dynamic_update_to_concat<1>") + push!(transform_passes_list, "dynamic_update_to_concat<1>") end if recognize_comms @@ -1206,9 +1206,11 @@ function optimization_passes( ",", ) if CONCATS_TO_DUS[] - transform_passes = transform_passes * ",enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform" + transform_passes = + transform_passes * + ",enzyme-hlo-generate-td{patterns=concat_to_onedim_dus},transform-interpreter,enzyme-hlo-remove-transform" if lower_comms - push!(lower_transform_passes, "concat_to_onedim_dus") + push!(lower_transform_passes, "concat_to_onedim_dus") end end func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index f39a3303aa..7e0c1bddff 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -406,7 +406,9 @@ end Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims) -@inline function Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims::Dims; kwargs...) where {T} +@inline function Base.similar( + AT::Type{<:ConcretePJRTArray{T}}, dims::Dims; kwargs... +) where {T} return Base.similar(AT, T, dims; kwargs...) end diff --git a/test/indexing.jl b/test/indexing.jl index e8a54fac99..9a4aad4a3c 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -48,10 +48,9 @@ end end function write_row_simple!(xs_mut, i, v) - xs_mut[:, i] = v - nothing + xs_mut[:, i] = v + return nothing end - @testset "setindex: DUS" begin x_ra = similar(ConcreteRArray{Float64}, (76, 100)) y_ra = similar(ConcreteRArray{Float64}, 76)