Skip to content

Commit 8d7cd3a

Browse files
committed
feat: make MLUtils into a weakdep & suppport MLDataDevices
1 parent 904cac0 commit 8d7cd3a

File tree

4 files changed

+31
-4
lines changed

4 files changed

+31
-4
lines changed

lib/OptimizationOptimisers/Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
name = "OptimizationOptimisers"
22
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
7-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
87
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
98
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
109
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1110
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1211
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1312

13+
[extensions]
14+
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
15+
OptimizationOptimisersMLUtilsExt = "MLUtils"
16+
17+
[weakdeps]
18+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
19+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
20+
1421
[compat]
22+
MLDataDevices = "1.1"
1523
MLUtils = "0.4.4"
1624
Optimisers = "0.2, 0.3"
1725
Optimization = "4"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationOptimisersMLDataDevicesExt
2+
3+
using MLDataDevices
4+
using OptimizationOptimisers
5+
6+
OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true
7+
8+
end
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationOptimisersMLUtilsExt
2+
3+
using MLUtils
4+
using OptimizationOptimisers
5+
6+
OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true
7+
8+
end

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OptimizationOptimisers
22

33
using Reexport, Printf, ProgressLogging
44
@reexport using Optimisers, Optimization
5-
using Optimization.SciMLBase, MLUtils
5+
using Optimization.SciMLBase
66

77
SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
88
SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -16,6 +16,8 @@ function SciMLBase.__init(
1616
kwargs...)
1717
end
1818

19+
isa_dataiterator(data) = false
20+
1921
function SciMLBase.__solve(cache::OptimizationCache{
2022
F,
2123
RC,
@@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
5759
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
5860
end
5961

60-
if cache.p isa MLUtils.DataLoader
62+
if isa_dataiterator(cache.p)
6163
data = cache.p
6264
dataiterate = true
6365
else
6466
data = [cache.p]
6567
dataiterate = false
6668
end
69+
6770
opt = cache.opt
6871
θ = copy(cache.u0)
6972
G = copy(θ)

0 commit comments

Comments
 (0)