Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 77862ce

Browse files
committed
feat: add warning on attempting to move architecture
1 parent ad7356c commit 77862ce

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxCore"
22
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "0.1.21"
4+
version = "0.1.22"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -12,10 +12,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1212

1313
[weakdeps]
1414
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1516
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1617

1718
[extensions]
1819
LuxCoreChainRulesCoreExt = "ChainRulesCore"
20+
LuxCoreMLDataDevicesExt = "MLDataDevices"
1921
LuxCoreEnzymeCoreExt = "EnzymeCore"
2022

2123
[compat]
@@ -26,6 +28,7 @@ DispatchDoctor = "0.4.10"
2628
EnzymeCore = "0.7.7"
2729
ExplicitImports = "1.9.0"
2830
Functors = "0.4.8"
31+
MLDataDevices = "1"
2932
Optimisers = "0.3"
3033
Random = "1.10"
3134
Setfield = "1"
@@ -36,9 +39,10 @@ julia = "1.10"
3639
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3740
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3841
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
42+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
3943
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
4044
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4145
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4246

4347
[targets]
44-
test = ["Aqua", "EnzymeCore", "ExplicitImports", "Optimisers", "Random", "Test"]
48+
test = ["Aqua", "EnzymeCore", "ExplicitImports", "MLDataDevices", "Optimisers", "Random", "Test"]

ext/LuxCoreMLDataDevicesExt.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module LuxCoreMLDataDevicesExt
2+
3+
using LuxCore: LuxCore
4+
using MLDataDevices: MLDataDevices
5+
6+
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
7+
ldev = Symbol(dev, :Device)
8+
@eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractExplicitLayer)
9+
@warn "Lux layers are stateless and hence don't participate in device transfers. \
10+
Apply this function on the parameters and states generated using \
11+
`LuxCore.setup`."
12+
return NN
13+
end
14+
end
15+
16+
end

test/runtests.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore
1+
using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore,
2+
MLDataDevices
23

34
rng = LuxCore._default_rng()
45

@@ -290,4 +291,14 @@ end
290291
@test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d))
291292
@test Const(d) isa Const
292293
end
294+
295+
@testset "Device Transfer Warnings" begin
296+
my_layer = Dense(2, 2)
297+
298+
dev = cpu_device()
299+
@test_logs (
300+
:warn, "Lux layers are stateless and hence don't participate in device \
301+
transfers. Apply this function on the parameters and states generated \
302+
using `LuxCore.setup`.") dev(my_layer)
303+
end
293304
end

0 commit comments

Comments
 (0)