Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit b292c86

Browse files
committed
fix: move LuxCore piracies over from Lux
1 parent 3b9469e commit b292c86

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxCore"
22
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "0.1.23"
4+
version = "0.1.24"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -11,22 +11,30 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1212

1313
[weakdeps]
14+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1415
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15-
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1616
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
17+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
18+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
19+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1720

1821
[extensions]
22+
LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"]
23+
LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"]
1924
LuxCoreChainRulesCoreExt = "ChainRulesCore"
20-
LuxCoreMLDataDevicesExt = "MLDataDevices"
2125
LuxCoreEnzymeCoreExt = "EnzymeCore"
26+
LuxCoreMLDataDevicesExt = "MLDataDevices"
2227

2328
[compat]
29+
ArrayInterface = "7.9"
2430
ChainRulesCore = "1.24"
2531
Compat = "4.15.0"
2632
DispatchDoctor = "0.4.10"
2733
EnzymeCore = "0.7.7"
2834
Functors = "0.4.12"
2935
MLDataDevices = "1"
3036
Random = "1.10"
37+
ReverseDiff = "1.15"
3138
Setfield = "1"
39+
Tracker = "0.2.34"
3240
julia = "1.10"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module LuxCoreArrayInterfaceReverseDiffExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using LuxCore: LuxCore, AbstractExplicitLayer
5+
using ReverseDiff: TrackedReal, TrackedArray
6+
7+
# AoS to SoA conversion
8+
function LuxCore.apply(
9+
m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st)
10+
@warn "Lux.apply(m::AbstractExplicitLayer, \
11+
x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
12+
Lux.apply(m::AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \
13+
st).\n\n\
14+
1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
15+
2. This might have performance implications. Check which layer was causing this \
16+
problem using `Lux.Experimental.@debug_mode`." maxlog=1
17+
return LuxCore.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
18+
end
19+
20+
## Prevent an infinite loop
21+
LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
22+
23+
end
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module LuxCoreArrayInterfaceTrackerExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using LuxCore: LuxCore, AbstractExplicitLayer
5+
using Tracker: TrackedReal, TrackedArray
6+
7+
# AoS to SoA conversion
8+
function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st)
9+
@warn "LuxCore.apply(m::AbstractExplicitLayer, \
10+
x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \
11+
LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\
12+
1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
13+
2. This might have performance implications. Check which layer was causing this \
14+
problem using `Lux.Experimental.@debug_mode`." maxlog=1
15+
return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st)
16+
end
17+
18+
## Prevent an infinite loop
19+
LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
20+
21+
end

ext/LuxCoreChainRulesCoreExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
module LuxCoreChainRulesCoreExt
22

3-
using ChainRulesCore: @non_differentiable
4-
using LuxCore: LuxCore
3+
using ChainRulesCore: ChainRulesCore, @non_differentiable
4+
using LuxCore: LuxCore, AbstractExplicitLayer
55
using Random: AbstractRNG
66

77
@non_differentiable LuxCore.replicate(::AbstractRNG)
88

9+
function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractExplicitLayer, x::Symbol)
10+
mₓ = getproperty(m, x)
11+
∇getproperty(_) = ntuple(Returns(ChainRulesCore.NoTangent()), 3)
12+
return mₓ, ∇getproperty
13+
end
14+
915
end

0 commit comments

Comments
 (0)