Skip to content

Commit 310a885

Browse files
committed
Copied style of Flux for params
1 parent b24434f commit 310a885

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/generic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t,
1414
printshifted(io::IO::Kernel,shift::Int) = print(io,"")
1515
Base.show(io::IO::Kernel) = print(io,nameof(typeof(κ)))
1616

17+
function params(k::Kernel)
18+
ps = []
19+
params!(ps,k)
20+
return ps
21+
end
22+
23+
function params!(ps,k::Kernel)
24+
for child in trainable(k)
25+
params!(ps,k)
26+
end
27+
end
28+
29+
params!(ps,x::AbstractArray) = push!(ps,x)
30+
31+
trainable(x) = ()
32+
1733
### Syntactic sugar for creating matrices and using kernel functions
1834
for k in subtypes(BaseKernel)
1935
@eval begin

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ base_transform(t::Transform) = eval(nameof(typeof(t)))
3838
For a kernel return a tuple with parameters of the transform followed by the specific parameters of the kernel
3939
For a transform return its parameters, for a `ChainTransform` return a vector of `params(t)`.
4040
"""
41-
params(k::Kernel) = (params(transform(k)),)
41+
params

0 commit comments

Comments
 (0)