Skip to content

Commit c4d33e5

Browse files
authored
Use templates as parameters (#588)
* Use templates as parameters * Uncomment deps
1 parent 8cee5a9 commit c4d33e5

File tree

6 files changed

+297
-223
lines changed

6 files changed

+297
-223
lines changed

lib/axon.ex

Lines changed: 209 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,52 @@ defmodule Axon do
405405
@doc """
406406
Trainable Axon parameter used to create custom layers.
407407
408+
Parameters are specified in usages of `Axon.layer` and will be
409+
automatically initialized and used in subsequent applications of
410+
Axon models.
411+
412+
You must specify a parameter "template" which can be a static template
413+
tensor or a function which takes model input templates and returns a
414+
template. It's most common to use functions because most parameters'
415+
shapes rely on input shape information.
416+
"""
417+
@doc type: :special
418+
# Bodiless head: declares the `opts` default shared by both clauses below.
def parameter(name, template, opts \\ [])
419+
420+
# Static-template clause: `template` is given directly as an `Nx.Tensor`.
def parameter(name, %Nx.Tensor{} = template, opts) do
421+
# Only `:initializer` and `:kind` are accepted here; defaults are filled in.
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
422+
initializer = validate_initializer!(opts[:initializer])
423+
# `||` falls back to `:parameter` when `:kind` was passed explicitly as nil/false.
kind = opts[:kind] || :parameter
424+
425+
# Normalize the given tensor through `Nx.to_template/1` before storing it.
template = Nx.to_template(template)
426+
427+
%Axon.Parameter{
428+
name: name,
429+
template: template,
430+
initializer: initializer,
431+
kind: kind,
432+
# Legacy fields, derived from the template itself.
# NOTE(review): presumably kept for code that still reads `type`/`shape` — confirm.
433+
type: Nx.type(template),
434+
shape: Nx.shape(template)
435+
}
436+
end
437+
438+
# Function-template clause: the function is stored as-is in `template` and is
# not called here. The legacy `type`/`shape` fields are not set in this clause.
def parameter(name, function, opts) when is_function(function) do
439+
# Same option validation and defaults as the tensor clause.
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
440+
initializer = validate_initializer!(opts[:initializer])
441+
kind = opts[:kind] || :parameter
442+
443+
%Axon.Parameter{
444+
name: name,
445+
template: function,
446+
initializer: initializer,
447+
kind: kind
448+
}
449+
end
450+
451+
@doc """
452+
Trainable Axon parameter used to create custom layers.
453+
408454
Parameters are specified in usages of `Axon.layer` and will
409455
be automatically initialized and used in subsequent applications
410456
of Axon models.
@@ -421,36 +467,35 @@ defmodule Axon do
421467
@doc type: :special
422468
# Bodiless head: declares the `opts` default for the `param/3` clauses.
# In this diff view, lines marked `-` are removed by the commit and lines
# marked `+` are added; unmarked (double-numbered) lines are unchanged context.
def param(name, shape, opts \\ [])
423469

424-
# (removed) Composite `{:map, inner_params}` clause; it built a parameter
# with `type: :map` and the inner parameters as `children`.
def param(name, {:map, [_ | _] = inner_params}, opts) do
425-
maybe_warn_on_param_opts(opts)
470+
# (added) Tuple-shape clause: builds a static template and delegates to `parameter/3`.
def param(name, shape, opts) when is_binary(name) and is_tuple(shape) do
471+
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
472+
# `:type` is popped so the remaining opts only hold keys `parameter/3` is called with.
{type, opts} = Keyword.pop(opts, :type, {:f, 32})
426473

427-
%Axon.Parameter{
428-
name: name,
429-
type: :map,
430-
children: inner_params
431-
}
474+
template = Nx.template(shape, type)
475+
parameter(name, template, opts)
432476
end
433477

434-
# (removed) Single clause that accepted either a tuple or a function as `shape`.
def param(name, shape, opts) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
478+
# (added) Function-shape clause: wraps the shape function into a template function.
def param(name, shape, opts) when is_binary(name) and is_function(shape) do
435479
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
436-
initializer = validate_initializer!(opts[:initializer])
437-
type = opts[:type] || {:f, 32}
438-
kind = opts[:kind] || :parameter
480+
{type, opts} = Keyword.pop(opts, :type, {:f, 32})
439481

440-
%Axon.Parameter{
441-
name: name,
442-
shape: shape,
443-
type: type,
444-
initializer: initializer,
445-
kind: kind
446-
}
482+
# The wrapper must have the same arity as the user's shape function.
{:arity, arity} = Function.info(shape, :arity)
483+
484+
template =
485+
# `shape_fun/2` returns a function of `arity` arguments that hands them to
# this callback as one list.
shape_fun(arity, fn templates ->
486+
# Convert each incoming template to its shape before calling the user's function.
shapes = Enum.map(List.wrap(templates), &Nx.shape/1)
487+
out_shape = apply(shape, shapes)
488+
Nx.template(out_shape, type)
489+
end)
490+
491+
parameter(name, template, opts)
447492
end
448493

449-
# (removed) Warned when options were passed to a composite parameter.
# NOTE(review): `in` tests list membership, so on a keyword list of
# `{key, value}` tuples this condition would not match an option key.
defp maybe_warn_on_param_opts(opts) do
450-
if :initializer in opts or :type in opts do
451-
Logger.warning(
452-
"Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
453-
)
494+
# (added) Generates one `shape_fun/2` clause per arity from 0 to 128 at compile time.
# An arity outside that range has no matching clause.
for i <- 0..128 do
495+
# `i` generated argument variables for the anonymous function below.
args = Macro.generate_arguments(i, __MODULE__)
496+
497+
defp shape_fun(unquote(i), callback) do
498+
# Returns a function of exactly `i` arguments that passes them to
# `callback` collected into a single list.
fn unquote_splicing(args) -> callback.(unquote(args)) end
454499
end
455500
end
456501

@@ -2583,25 +2628,63 @@ defmodule Axon do
25832628
activation = opts[:activation]
25842629
gate = opts[:gate]
25852630
unroll = opts[:unroll]
2631+
25862632
kernel_initializer = opts[:kernel_initializer]
25872633

2588-
input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end
2589-
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end
2590-
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end
2634+
input_kernel_template = fn inp, _, _ ->
2635+
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :lstm)
2636+
Nx.template(shape, :f32)
2637+
end
25912638

2592-
wii = param("wii", input_kernel_shape, initializer: kernel_initializer)
2593-
wif = param("wif", input_kernel_shape, initializer: kernel_initializer)
2594-
wig = param("wig", input_kernel_shape, initializer: kernel_initializer)
2595-
wio = param("wio", input_kernel_shape, initializer: kernel_initializer)
2639+
hidden_kernel_template = fn inp, _, _ ->
2640+
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :lstm)
2641+
Nx.template(shape, :f32)
2642+
end
2643+
2644+
bias_template = fn inp, _, _ ->
2645+
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :lstm)
2646+
Nx.template(shape, :f32)
2647+
end
2648+
2649+
initializer = fn prefix, init ->
2650+
fn shape, type, key ->
2651+
split_key = Nx.Random.split(key, parts: 4)
2652+
2653+
init =
2654+
if is_atom(init) do
2655+
apply(Axon.Initializers, init, [])
2656+
else
2657+
init
2658+
end
25962659

2597-
whi = param("whi", hidden_kernel_shape, initializer: kernel_initializer)
2598-
whf = param("whf", hidden_kernel_shape, initializer: kernel_initializer)
2599-
whg = param("whg", hidden_kernel_shape, initializer: kernel_initializer)
2600-
who = param("who", hidden_kernel_shape, initializer: kernel_initializer)
2660+
fun =
2661+
case init do
2662+
init when is_function(init, 2) ->
2663+
fn _ -> init.(shape, type) end
2664+
2665+
init when is_function(init, 3) ->
2666+
fn key -> init.(shape, type, key) end
2667+
end
2668+
2669+
%{
2670+
"#{prefix}i" => fun.(split_key[0]),
2671+
"#{prefix}f" => fun.(split_key[1]),
2672+
"#{prefix}g" => fun.(split_key[2]),
2673+
"#{prefix}o" => fun.(split_key[3])
2674+
}
2675+
end
2676+
end
26012677

26022678
# Parameters
2603-
input_kernel = param("input_kernel", {:map, [wii, wif, wig, wio]})
2604-
hidden_kernel = param("hidden_kernel", {:map, [whi, whf, whg, who]})
2679+
input_kernel =
2680+
parameter("input_kernel", input_kernel_template,
2681+
initializer: initializer.("wi", kernel_initializer)
2682+
)
2683+
2684+
hidden_kernel =
2685+
parameter("hidden_kernel", hidden_kernel_template,
2686+
initializer: initializer.("wh", kernel_initializer)
2687+
)
26052688

26062689
hidden_state_name =
26072690
case opts[:name] do
@@ -2620,12 +2703,7 @@ defmodule Axon do
26202703
if opts[:use_bias] do
26212704
bias_initializer = opts[:bias_initializer]
26222705

2623-
bi = param("bi", bias_shape, initializer: bias_initializer)
2624-
bf = param("bf", bias_shape, initializer: bias_initializer)
2625-
bg = param("bg", bias_shape, initializer: bias_initializer)
2626-
bo = param("bo", bias_shape, initializer: bias_initializer)
2627-
2628-
bias = param("bias", {:map, [bi, bf, bg, bo]})
2706+
bias = parameter("bias", bias_template, initializer: initializer.("b", bias_initializer))
26292707

26302708
{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm}
26312709
else
@@ -2790,22 +2868,58 @@ defmodule Axon do
27902868
gate = opts[:gate]
27912869
unroll = opts[:unroll]
27922870

2793-
input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :gru) end
2794-
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :gru) end
2795-
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :gru) end
2871+
input_kernel_template = fn inp, _, _ ->
2872+
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :gru)
2873+
Nx.template(shape, :f32)
2874+
end
2875+
2876+
hidden_kernel_template = fn inp, _, _ ->
2877+
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :gru)
2878+
Nx.template(shape, :f32)
2879+
end
27962880

2797-
kernel_initializer = opts[:kernel_initializer]
2881+
bias_template = fn inp, _, _ ->
2882+
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :gru)
2883+
Nx.template(shape, :f32)
2884+
end
27982885

2799-
wir = param("wir", input_kernel_shape, initializer: kernel_initializer)
2800-
wiz = param("wiz", input_kernel_shape, initializer: kernel_initializer)
2801-
win = param("win", input_kernel_shape, initializer: kernel_initializer)
2886+
initializer = fn prefix, init ->
2887+
fn shape, type, key ->
2888+
split_key = Nx.Random.split(key, parts: 3)
28022889

2803-
whr = param("whr", hidden_kernel_shape, initializer: kernel_initializer)
2804-
whz = param("whz", hidden_kernel_shape, initializer: kernel_initializer)
2805-
whn = param("whn", hidden_kernel_shape, initializer: kernel_initializer)
2890+
init =
2891+
if is_atom(init) do
2892+
apply(Axon.Initializers, init, [])
2893+
else
2894+
init
2895+
end
28062896

2807-
input_kernel = param("input_kernel", {:map, [wir, wiz, win]})
2808-
hidden_kernel = param("hidden_kernel", {:map, [whr, whz, whn]})
2897+
fun =
2898+
case init do
2899+
init when is_function(init, 2) ->
2900+
fn _ -> init.(shape, type) end
2901+
2902+
init when is_function(init, 3) ->
2903+
fn key -> init.(shape, type, key) end
2904+
end
2905+
2906+
%{
2907+
"#{prefix}r" => fun.(split_key[0]),
2908+
"#{prefix}z" => fun.(split_key[1]),
2909+
"#{prefix}n" => fun.(split_key[2])
2910+
}
2911+
end
2912+
end
2913+
2914+
input_kernel =
2915+
parameter("input_kernel", input_kernel_template,
2916+
initializer: initializer.("wi", opts[:kernel_initializer])
2917+
)
2918+
2919+
hidden_kernel =
2920+
parameter("hidden_kernel", hidden_kernel_template,
2921+
initializer: initializer.("wh", opts[:kernel_initializer])
2922+
)
28092923

28102924
hidden_state_name =
28112925
case opts[:name] do
@@ -2822,14 +2936,34 @@ defmodule Axon do
28222936

28232937
inputs =
28242938
if opts[:use_bias] do
2825-
bias_initializer = opts[:bias_initializer]
2939+
bias_initializer = fn shape, type, key ->
2940+
split_key = Nx.Random.split(key, parts: 4)
2941+
2942+
init =
2943+
if is_atom(opts[:bias_initializer]) do
2944+
apply(Axon.Initializers, opts[:bias_initializer], [])
2945+
else
2946+
opts[:bias_initializer]
2947+
end
28262948

2827-
br = param("br", bias_shape, initializer: bias_initializer)
2828-
bz = param("bz", bias_shape, initializer: bias_initializer)
2829-
bin = param("bin", bias_shape, initializer: bias_initializer)
2830-
bhn = param("bhn", bias_shape, initializer: bias_initializer)
2949+
fun =
2950+
case init do
2951+
init when is_function(init, 2) ->
2952+
fn _ -> init.(shape, type) end
2953+
2954+
init when is_function(init, 3) ->
2955+
fn key -> init.(shape, type, key) end
2956+
end
2957+
2958+
%{
2959+
"br" => fun.(split_key[0]),
2960+
"bz" => fun.(split_key[1]),
2961+
"bin" => fun.(split_key[2]),
2962+
"bhn" => fun.(split_key[3])
2963+
}
2964+
end
28312965

2832-
bias = param("bias", {:map, [br, bz, bin, bhn]})
2966+
bias = parameter("bias", bias_template, initializer: bias_initializer)
28332967

28342968
[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias]
28352969
else
@@ -2983,23 +3117,26 @@ defmodule Axon do
29833117
unroll = opts[:unroll]
29843118
kernel_initializer = opts[:kernel_initializer]
29853119

2986-
hidden_kernel_shape = fn _, {inp, _}, _ ->
2987-
shape = Tuple.delete_at(inp, 1)
2988-
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
3120+
hidden_kernel_template = fn _, {inp, _}, _ ->
3121+
shape = Tuple.delete_at(Nx.shape(inp), 1)
3122+
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
3123+
Nx.template(shape, :f32)
29893124
end
29903125

2991-
input_kernel_shape = fn inp, _, _ ->
2992-
shape = Tuple.delete_at(inp, 1)
2993-
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
3126+
input_kernel_template = fn inp, _, _ ->
3127+
shape = Tuple.delete_at(Nx.shape(inp), 1)
3128+
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
3129+
Nx.template(shape, :f32)
29943130
end
29953131

2996-
bias_shape = fn inp, _, _ ->
2997-
shape = Tuple.delete_at(inp, 1)
2998-
Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
3132+
bias_template = fn inp, _, _ ->
3133+
shape = Tuple.delete_at(Nx.shape(inp), 1)
3134+
shape = Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
3135+
Nx.template(shape, :f32)
29993136
end
30003137

3001-
wi = param("input_kernel", input_kernel_shape, initializer: kernel_initializer)
3002-
wh = param("hidden_kernel", hidden_kernel_shape, initializer: kernel_initializer)
3138+
wi = parameter("input_kernel", input_kernel_template, initializer: kernel_initializer)
3139+
wh = parameter("hidden_kernel", hidden_kernel_template, initializer: kernel_initializer)
30033140

30043141
hidden_state_name =
30053142
case opts[:name] do
@@ -3017,7 +3154,7 @@ defmodule Axon do
30173154
{inputs, op} =
30183155
if opts[:use_bias] do
30193156
bias_initializer = opts[:bias_initializer]
3020-
b = param("bias", bias_shape, initializer: bias_initializer)
3157+
b = parameter("bias", bias_template, initializer: bias_initializer)
30213158
{[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm}
30223159
else
30233160
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm}

0 commit comments

Comments
 (0)