From 8ec2a6f458ab7ed40356575858aa85b6d957daf7 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Fri, 30 May 2025 02:08:47 +0200
Subject: [PATCH 1/2] reinstate reactant's tests

---
 ext/FluxEnzymeExt/FluxEnzymeExt.jl | 18 ++++++++----------
 test/runtests.jl                   |  6 +++---
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
index b124a4a973..b31764e244 100644
--- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl
+++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -14,16 +14,14 @@ EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
 
 ### gradient & withgradient
 
-# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
-# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 probably this can be `make_zero!(Ref(dup.dval))`
-_make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
-function _make_zero_inner!(x::AbstractArray{<:Number})
-    Optimisers.isnumeric(x) || return
-    Optimisers.maywrite(x) || error("can't handle this")
-    fill!(x, zero(eltype(x)))
-    nothing
-end
-_make_zero_inner!(x) = nothing # any other Functors leaf type
+# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 Enzyme.make_zero! can be used,
+# but we have to wrap the model in a Ref, as make_zero! complains about e.g. LayerNorm having immutable differentiable fields
+_make_zero!(model) = Enzyme.make_zero!(Ref(model))
+
+## OLD CODE
+# _make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
+# _make_zero_inner!(x::AbstractArray{<:Number}) = Enzyme.make_zero!(x)
+# _make_zero_inner!(x) = nothing # any other Functors leaf type
 
 #=
 # This _make_zero! matches what Flux allows elsewhere:
 julia> Flux.setup(Adam(), (1:3.)')
diff --git a/test/runtests.jl b/test/runtests.jl
index 64118a2f56..65f456d7ca 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -25,11 +25,11 @@ using Zygote: Zygote
 # ENV["FLUX_TEST_AMDGPU"] = "true"
 # ENV["FLUX_TEST_CUDA"] = "true"
 # ENV["FLUX_TEST_METAL"] = "true"
-# ENV["FLUX_TEST_CPU"] = "false"
+ENV["FLUX_TEST_CPU"] = "false"
 # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
 # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
-# ENV["FLUX_TEST_ENZYME"] = "false"
-ENV["FLUX_TEST_REACTANT"] = "false"
+ENV["FLUX_TEST_ENZYME"] = "false"
+ENV["FLUX_TEST_REACTANT"] = "true"
 
"true" : "false") == "true" const FLUX_TEST_CPU = get(ENV, "FLUX_TEST_CPU", "true") == "true" From 514fc5685fe0a20ed81c5a7bfcf95c396fd333c9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Fri, 30 May 2025 09:46:33 +0200 Subject: [PATCH 2/2] work --- test/Project.toml | 2 ++ test/ext_reactant/reactant.jl | 52 +++++++++++++++++------------------ test/runtests.jl | 1 + 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 34ec552afe..22d6f9aa2a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -15,6 +16,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index f84651d6f2..494402db8a 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -10,37 +10,37 @@ end models_xs = [ (Dense(2=>4), randn(Float32, 2), "Dense"), - (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"), + # (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"), - (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), + # (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), - (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), + # (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), - # all arguments must have at least the same length of the firs one - # a = (Conv((3, 3), 2 => 3),) - # b = ((σ = nothing, weight = Float32[-0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815], bias = Float32[0.33333334, 0.33333334, 0.33333334], stride = nothing, pad = nothing, dilation = nothing, groups = nothing),) - # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), + # # all arguments must have at least the same length of the firs one + # # a = (Conv((3, 3), 2 => 3),) + # # b = ((σ = nothing, weight = Float32[-0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 
+        # # b = ((σ = nothing, weight = Float32[-0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815], bias = Float32[0.33333334, 0.33333334, 0.33333334], stride = nothing, pad = nothing, dilation = nothing, groups = nothing),)
+        # # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
 
-        # all arguments must have at least the same length as the first one
-        # a = (Chain(Conv((3, 3), 2 => 3), Conv((3, 3), 3 => 1, tanh)),)
-        # b = ((layers = ((σ = nothing, weight = Float32[0.2703631 0.15815677 0.2918554; 0.20036785 0.43450722 0.3525422; 0.3541182 0.32077286 0.44091386;;; 0.3233156 0.08538988 0.25763267; 0.413441 0.66042584 0.16991; 0.36993486 0.5990643 0.10123589;;;; 0.45728725 0.500834 0.46808332; 0.3662355 0.35068494 0.27277413; 0.44974697 0.47245422 0.10595817;;; 0.36255562 0.6111583 0.52779496; 0.27237993 0.25857046 0.33643073; 0.6679214 0.066386 0.32072845;;;; -0.4879305 -0.59246373 -0.59834677; -0.55097836 -0.5006755 -0.4233263; -0.72177917 -0.65806544 -0.38224664;;; -0.4765812 -0.6856963 -0.5864509; -0.6547631 -0.55094117 -0.38632843; -0.74521375 -0.3817107 -0.48642716], bias = Float32[0.7159346, 0.7152501, -1.0509125], stride = nothing, pad = nothing, dilation = nothing, groups = nothing), (σ = nothing, weight = Float32[0.32858944 -0.10135343 -0.25303265; -0.13622479 0.023095237 0.1746222; 0.18829267 -0.5047879 0.07125988;;; 0.023820637 -0.06595295 -0.003393827; -0.111125976 0.0023178488 0.08700531; -0.073591515 0.057915907 0.048598815;;; 0.016056929 -0.5129501 -0.15588683; -0.3756476 -0.09993523 -0.45654622; -0.3688693 -0.33078116 -0.4093926;;;;], bias = Float32[0.77964276], stride = nothing, pad = nothing, dilation = nothing, groups = nothing)),),)
-        # (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
+        # # all arguments must have at least the same length as the first one
+        # # a = (Chain(Conv((3, 3), 2 => 3), Conv((3, 3), 3 => 1, tanh)),)
+        # # b = ((layers = ((σ = nothing, weight = Float32[0.2703631 0.15815677 0.2918554; 0.20036785 0.43450722 0.3525422; 0.3541182 0.32077286 0.44091386;;; 0.3233156 0.08538988 0.25763267; 0.413441 0.66042584 0.16991; 0.36993486 0.5990643 0.10123589;;;; 0.45728725 0.500834 0.46808332; 0.3662355 0.35068494 0.27277413; 0.44974697 0.47245422 0.10595817;;; 0.36255562 0.6111583 0.52779496; 0.27237993 0.25857046 0.33643073; 0.6679214 0.066386 0.32072845;;;; -0.4879305 -0.59246373 -0.59834677; -0.55097836 -0.5006755 -0.4233263; -0.72177917 -0.65806544 -0.38224664;;; -0.4765812 -0.6856963 -0.5864509; -0.6547631 -0.55094117 -0.38632843; -0.74521375 -0.3817107 -0.48642716], bias = Float32[0.7159346, 0.7152501, -1.0509125], stride = nothing, pad = nothing, dilation = nothing, groups = nothing), (σ = nothing, weight = Float32[0.32858944 -0.10135343 -0.25303265; -0.13622479 0.023095237 0.1746222; 0.18829267 -0.5047879 0.07125988;;; 0.023820637 -0.06595295 -0.003393827; -0.111125976 0.0023178488 0.08700531; -0.073591515 0.057915907 0.048598815;;; 0.016056929 -0.5129501 -0.15588683; -0.3756476 -0.09993523 -0.45654622; -0.3688693 -0.33078116 -0.4093926;;;;], bias = Float32[0.77964276], stride = nothing, pad = nothing, dilation = nothing, groups = nothing)),),)
+        # # (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
 
-        # https://github.com/EnzymeAD/Enzyme-JAX/issues/221
-        # (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
+        # # https://github.com/EnzymeAD/Enzyme-JAX/issues/221
+        # # (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
 
-        (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
+        # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
 
-        # error: 'stablehlo.multiply' op requires compatible types for all operands and results
-        # This requires an issue to be opened.
-        # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
+        # # error: 'stablehlo.multiply' op requires compatible types for all operands and results
+        # # This requires an issue to be opened.
+        # # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
 
-        (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
+        # (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
 
-        # error: inferred shape '[1, 3, 9, 9]' is incompatible with return type of operation 'tensor<1x3x5x5xf32>'
-        # (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
+        # # error: inferred shape '[1, 3, 9, 9]' is incompatible with return type of operation 'tensor<1x3x5x5xf32>'
+        # # (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
 
-        # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # Apparent correctness issue
+        # # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # Apparent correctness issue
     ]
 
     for (model, x, name) in models_xs
@@ -76,11 +76,11 @@ end
 
     models_xs = [
-        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+        # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+        # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+        # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+        # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+        # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
     ]
 
     for (model, x, name) in models_xs
diff --git a/test/runtests.jl b/test/runtests.jl
index 65f456d7ca..3e731a5bc4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -200,6 +200,7 @@ end
 # │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
     Pkg.add("Reactant")
    using Reactant: Reactant
+    Reactant.set_default_backend("cpu")
     @testset "Reactant" begin
         include("ext_reactant/test_utils_reactant.jl")
         include("ext_reactant/reactant.jl")
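
Note on the _make_zero! change in [PATCH 1/2]: the sketch below is not part of either patch. It illustrates the pattern the extension now relies on, resetting a Duplicated model's shadow gradient with Enzyme.make_zero!(Ref(...)). The model, loss function, and data here are made up for illustration, and it assumes Flux with the Enzyme extension loaded and an Enzyme version containing https://github.com/EnzymeAD/Enzyme.jl/pull/1961.

    # Illustrative sketch only, not part of the patches above.
    using Flux, Enzyme

    model = Chain(Dense(2 => 4, tanh), LayerNorm(4))  # LayerNorm has immutable differentiable fields
    dup = Duplicated(model, Enzyme.make_zero(model))  # primal model plus a shadow that receives the gradient

    loss(m, x) = sum(abs2, m(x))
    x = randn(Float32, 2, 8)
    Flux.gradient(loss, dup, Const(x))  # the gradient accumulates into dup.dval

    # Calling Enzyme.make_zero!(dup.dval) directly can error on immutable
    # differentiable fields like LayerNorm's; wrapping in a Ref gives
    # make_zero! a mutable root. This is what the new _make_zero! does:
    Enzyme.make_zero!(Ref(dup.dval))  # shadow reset to zero

Since the extension uses _make_zero! to reset the shadows of Duplicated arguments for gradient and withgradient, this change lets the same Duplicated model be reused across gradient calls even when it contains layers such as LayerNorm.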