Skip to content

Commit a3840ae

Browse files
authored
fix: try fixing broken tests (#1279)
* fix: try fixing AMDGPU kaiming normal
* docs: missing NNlib function
* fix: incorrect package usage
* fix: incorrect package usage
* fix: add device to GenericBroadcastOp
* chore: run the formatter
* test: skip some tests
* test: mark weight initializers tests as broken
1 parent 7e8f41c commit a3840ae

File tree

11 files changed

+26
-20
lines changed

11 files changed

+26
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.11.1"
+version = "1.11.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api/NN_Primitives/NNlib.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ ctc_loss
 ```@docs
 logsumexp
 glu
+NNlib.@disallow_spawns
 ```
 
 !!! tip

ext/LuxReactantExt/patches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))
+Utils.vec(x::AnyTracedRArray) = ReactantCore.materialize_traced_array(vec(x))
 
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

lib/LuxLib/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.7.1"
+version = "1.7.2"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 module LuxLibcuDNNExt
 
 using LuxLib: LuxLib, Optional, ∂∅, Impl
-using LuxLib.Utils: safe_reshape, safe_vec, unsafe_known, recursive_unthunk
+using LuxLib.Utils: Utils, safe_reshape, safe_vec, unsafe_known, recursive_unthunk
 using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray, DenseCuVector
 using ChainRulesCore: ChainRulesCore
 using cuDNN:

lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ function Impl.batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, args...)
 
     y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...)
 
-    CUDA.unsafe_free!(γ)
-    CUDA.unsafe_free!(β)
+    Utils.unsafe_free!(γ)
+    Utils.unsafe_free!(β)
 
     return y, xμ, xσ⁻²
 end
@@ -136,10 +136,10 @@ function Impl.∇batchnorm_cudnn(
 
     ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn(γ, β, x, ∂y, rμ, rσ², args...)
 
-    CUDA.unsafe_free!(γ)
-    CUDA.unsafe_free!(β)
-    CUDA.unsafe_free!(∂γ)
-    CUDA.unsafe_free!(∂β)
+    Utils.unsafe_free!(γ)
+    Utils.unsafe_free!(β)
+    Utils.unsafe_free!(∂γ)
+    Utils.unsafe_free!(∂β)
 
     return nothing, nothing, ∂x
 end

lib/LuxLib/src/impl/batched_mul.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function batched_matmul(
 end
 
 function batched_matmul(
-    opmode::GPUBroadcastOp{AMDGPUDevice},
+    opmode::Union{GPUBroadcastOp{AMDGPUDevice},GenericBroadcastOp{AMDGPUDevice}},
     x::AbstractArray{<:Complex,3},
     y::AbstractArray{<:Complex,3},
 )

lib/LuxLib/src/traits.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ abstract type AbstractInternalArrayOpMode end
 
 abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end
 
-struct GenericBroadcastOp <: AbstractBroadcastOpMode end
+struct GenericBroadcastOp{dev} <: AbstractBroadcastOpMode end
 struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end
 struct LoopedArrayOp <: AbstractInternalArrayOpMode end
 
@@ -192,15 +192,15 @@ Currently supported modes are:
 """
 function internal_operation_mode(xs::Tuple)
     xs = filter(!isnothing, xs)
-    known(Traits.use_generic_broadcasting(xs)) && return GenericBroadcastOp()
-
     dev = get_device_type(xs)
+
+    known(Traits.use_generic_broadcasting(xs)) && return GenericBroadcastOp{dev}()
     dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}()
-    dev <: ReactantDevice && return GenericBroadcastOp()
+    dev <: ReactantDevice && return GenericBroadcastOp{dev}()
 
     # This check needs to be done after the GPU Check
     known(Utils.unrolled_any(!Traits.fast_scalar_indexing, xs)) &&
-        return GenericBroadcastOp()
+        return GenericBroadcastOp{dev}()
     return LoopedArrayOp()
 end
 internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,))

lib/LuxLib/test/others/forwarddiff_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@
 
     test_jvp_computation(x -> op(x, kernel_size; stride, pad), x, u, ongpu)
 
+    # NNlib doesn't define ∇meanpool and ∇maxpool for AMDGPU properly
+    mode == "amdgpu" && continue
+
     test_jvp_computation(
         x ->
             only(Zygote.gradient(x -> sum(op(x, kernel_size; stride, pad)), x)),

lib/WeightInitializers/src/initializers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function kaiming_normal(
 ) where {T<:Number}
     std = T(gain) / sqrt(T(first(Utils.nfan(dims...))))
     x = DeviceAgnostic.randn(rng, T, dims...)
-    x .*= std
+    broadcast!(Base.Fix2(*, std), x, x)
     return x
 end

0 commit comments

Comments
 (0)