@@ -405,6 +405,52 @@ defmodule Axon do
405
405
@ doc """
406
406
Trainable Axon parameter used to create custom layers.
407
407
408
+ Parameters are specified in usages of `Axon.layer` and will be
409
+ automatically initialized and used in subsequent applications of
410
+ Axon models.
411
+
412
+ You must specify a parameter "template" which can be a static template
413
+ tensor or a function which takes model input templates and returns a
414
+ template. It's most common to use functions because most parameters'
415
+ shapes rely on input shape information.
416
+ """
417
+ @ doc type: :special
418
+ def parameter ( name , template , opts \\ [ ] )
419
+
420
+ def parameter ( name , % Nx.Tensor { } = template , opts ) do
421
+ opts = Keyword . validate! ( opts , initializer: :glorot_uniform , kind: :parameter )
422
+ initializer = validate_initializer! ( opts [ :initializer ] )
423
+ kind = opts [ :kind ] || :parameter
424
+
425
+ template = Nx . to_template ( template )
426
+
427
+ % Axon.Parameter {
428
+ name: name ,
429
+ template: template ,
430
+ initializer: initializer ,
431
+ kind: kind ,
432
+ # Legacy
433
+ type: Nx . type ( template ) ,
434
+ shape: Nx . shape ( template )
435
+ }
436
+ end
437
+
438
+ def parameter ( name , function , opts ) when is_function ( function ) do
439
+ opts = Keyword . validate! ( opts , initializer: :glorot_uniform , kind: :parameter )
440
+ initializer = validate_initializer! ( opts [ :initializer ] )
441
+ kind = opts [ :kind ] || :parameter
442
+
443
+ % Axon.Parameter {
444
+ name: name ,
445
+ template: function ,
446
+ initializer: initializer ,
447
+ kind: kind
448
+ }
449
+ end
450
+
451
+ @ doc """
452
+ Trainable Axon parameter used to create custom layers.
453
+
408
454
Parameters are specified in usages of `Axon.layer` and will
409
455
be automatically initialized and used in subsequent applications
410
456
of Axon models.
@@ -421,36 +467,35 @@ defmodule Axon do
421
467
@ doc type: :special
422
468
def param ( name , shape , opts \\ [ ] )
423
469
424
- def param ( name , { :map , [ _ | _ ] = inner_params } , opts ) do
425
- maybe_warn_on_param_opts ( opts )
470
+ def param ( name , shape , opts ) when is_binary ( name ) and is_tuple ( shape ) do
471
+ opts = Keyword . validate! ( opts , initializer: :glorot_uniform , type: { :f , 32 } , kind: :parameter )
472
+ { type , opts } = Keyword . pop ( opts , :type , { :f , 32 } )
426
473
427
- % Axon.Parameter {
428
- name: name ,
429
- type: :map ,
430
- children: inner_params
431
- }
474
+ template = Nx . template ( shape , type )
475
+ parameter ( name , template , opts )
432
476
end
433
477
434
- def param ( name , shape , opts ) when is_binary ( name ) and ( is_tuple ( shape ) or is_function ( shape ) ) do
478
+ def param ( name , shape , opts ) when is_binary ( name ) and is_function ( shape ) do
435
479
opts = Keyword . validate! ( opts , initializer: :glorot_uniform , type: { :f , 32 } , kind: :parameter )
436
- initializer = validate_initializer! ( opts [ :initializer ] )
437
- type = opts [ :type ] || { :f , 32 }
438
- kind = opts [ :kind ] || :parameter
480
+ { type , opts } = Keyword . pop ( opts , :type , { :f , 32 } )
439
481
440
- % Axon.Parameter {
441
- name: name ,
442
- shape: shape ,
443
- type: type ,
444
- initializer: initializer ,
445
- kind: kind
446
- }
482
+ { :arity , arity } = Function . info ( shape , :arity )
483
+
484
+ template =
485
+ shape_fun ( arity , fn templates ->
486
+ shapes = Enum . map ( List . wrap ( templates ) , & Nx . shape / 1 )
487
+ out_shape = apply ( shape , shapes )
488
+ Nx . template ( out_shape , type )
489
+ end )
490
+
491
+ parameter ( name , template , opts )
447
492
end
448
493
449
- defp maybe_warn_on_param_opts ( opts ) do
450
- if :initializer in opts or :type in opts do
451
- Logger . warning (
452
- "Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
453
- )
494
+ for i <- 0 .. 128 do
495
+ args = Macro . generate_arguments ( i , __MODULE__ )
496
+
497
+ defp shape_fun ( unquote ( i ) , callback ) do
498
+ fn unquote_splicing ( args ) -> callback . ( unquote ( args ) ) end
454
499
end
455
500
end
456
501
@@ -2583,25 +2628,63 @@ defmodule Axon do
2583
2628
activation = opts [ :activation ]
2584
2629
gate = opts [ :gate ]
2585
2630
unroll = opts [ :unroll ]
2631
+
2586
2632
kernel_initializer = opts [ :kernel_initializer ]
2587
2633
2588
- input_kernel_shape = fn inp , _ , _ -> Axon.Shape . rnn_input_kernel ( inp , units , :lstm ) end
2589
- hidden_kernel_shape = fn inp , _ , _ -> Axon.Shape . rnn_hidden_kernel ( inp , units , :lstm ) end
2590
- bias_shape = fn inp , _ , _ -> Axon.Shape . rnn_bias ( inp , units , :lstm ) end
2634
+ input_kernel_template = fn inp , _ , _ ->
2635
+ shape = Axon.Shape . rnn_input_kernel ( Nx . shape ( inp ) , units , :lstm )
2636
+ Nx . template ( shape , :f32 )
2637
+ end
2591
2638
2592
- wii = param ( "wii" , input_kernel_shape , initializer: kernel_initializer )
2593
- wif = param ( "wif" , input_kernel_shape , initializer: kernel_initializer )
2594
- wig = param ( "wig" , input_kernel_shape , initializer: kernel_initializer )
2595
- wio = param ( "wio" , input_kernel_shape , initializer: kernel_initializer )
2639
+ hidden_kernel_template = fn inp , _ , _ ->
2640
+ shape = Axon.Shape . rnn_hidden_kernel ( Nx . shape ( inp ) , units , :lstm )
2641
+ Nx . template ( shape , :f32 )
2642
+ end
2643
+
2644
+ bias_template = fn inp , _ , _ ->
2645
+ shape = Axon.Shape . rnn_bias ( Nx . shape ( inp ) , units , :lstm )
2646
+ Nx . template ( shape , :f32 )
2647
+ end
2648
+
2649
+ initializer = fn prefix , init ->
2650
+ fn shape , type , key ->
2651
+ split_key = Nx.Random . split ( key , parts: 4 )
2652
+
2653
+ init =
2654
+ if is_atom ( init ) do
2655
+ apply ( Axon.Initializers , init , [ ] )
2656
+ else
2657
+ init
2658
+ end
2596
2659
2597
- whi = param ( "whi" , hidden_kernel_shape , initializer: kernel_initializer )
2598
- whf = param ( "whf" , hidden_kernel_shape , initializer: kernel_initializer )
2599
- whg = param ( "whg" , hidden_kernel_shape , initializer: kernel_initializer )
2600
- who = param ( "who" , hidden_kernel_shape , initializer: kernel_initializer )
2660
+ fun =
2661
+ case init do
2662
+ init when is_function ( init , 2 ) ->
2663
+ fn _ -> init . ( shape , type ) end
2664
+
2665
+ init when is_function ( init , 3 ) ->
2666
+ fn key -> init . ( shape , type , key ) end
2667
+ end
2668
+
2669
+ % {
2670
+ "#{ prefix } i" => fun . ( split_key [ 0 ] ) ,
2671
+ "#{ prefix } f" => fun . ( split_key [ 1 ] ) ,
2672
+ "#{ prefix } g" => fun . ( split_key [ 2 ] ) ,
2673
+ "#{ prefix } o" => fun . ( split_key [ 3 ] )
2674
+ }
2675
+ end
2676
+ end
2601
2677
2602
2678
# Parameters
2603
- input_kernel = param ( "input_kernel" , { :map , [ wii , wif , wig , wio ] } )
2604
- hidden_kernel = param ( "hidden_kernel" , { :map , [ whi , whf , whg , who ] } )
2679
+ input_kernel =
2680
+ parameter ( "input_kernel" , input_kernel_template ,
2681
+ initializer: initializer . ( "wi" , kernel_initializer )
2682
+ )
2683
+
2684
+ hidden_kernel =
2685
+ parameter ( "hidden_kernel" , hidden_kernel_template ,
2686
+ initializer: initializer . ( "wh" , kernel_initializer )
2687
+ )
2605
2688
2606
2689
hidden_state_name =
2607
2690
case opts [ :name ] do
@@ -2620,12 +2703,7 @@ defmodule Axon do
2620
2703
if opts [ :use_bias ] do
2621
2704
bias_initializer = opts [ :bias_initializer ]
2622
2705
2623
- bi = param ( "bi" , bias_shape , initializer: bias_initializer )
2624
- bf = param ( "bf" , bias_shape , initializer: bias_initializer )
2625
- bg = param ( "bg" , bias_shape , initializer: bias_initializer )
2626
- bo = param ( "bo" , bias_shape , initializer: bias_initializer )
2627
-
2628
- bias = param ( "bias" , { :map , [ bi , bf , bg , bo ] } )
2706
+ bias = parameter ( "bias" , bias_template , initializer: initializer . ( "b" , bias_initializer ) )
2629
2707
2630
2708
{ [ x , hidden_state , opts [ :mask ] , input_kernel , hidden_kernel , bias ] , :lstm }
2631
2709
else
@@ -2790,22 +2868,58 @@ defmodule Axon do
2790
2868
gate = opts [ :gate ]
2791
2869
unroll = opts [ :unroll ]
2792
2870
2793
- input_kernel_shape = fn inp , _ , _ -> Axon.Shape . rnn_input_kernel ( inp , units , :gru ) end
2794
- hidden_kernel_shape = fn inp , _ , _ -> Axon.Shape . rnn_hidden_kernel ( inp , units , :gru ) end
2795
- bias_shape = fn inp , _ , _ -> Axon.Shape . rnn_bias ( inp , units , :gru ) end
2871
+ input_kernel_template = fn inp , _ , _ ->
2872
+ shape = Axon.Shape . rnn_input_kernel ( Nx . shape ( inp ) , units , :gru )
2873
+ Nx . template ( shape , :f32 )
2874
+ end
2875
+
2876
+ hidden_kernel_template = fn inp , _ , _ ->
2877
+ shape = Axon.Shape . rnn_hidden_kernel ( Nx . shape ( inp ) , units , :gru )
2878
+ Nx . template ( shape , :f32 )
2879
+ end
2796
2880
2797
- kernel_initializer = opts [ :kernel_initializer ]
2881
+ bias_template = fn inp , _ , _ ->
2882
+ shape = Axon.Shape . rnn_bias ( Nx . shape ( inp ) , units , :gru )
2883
+ Nx . template ( shape , :f32 )
2884
+ end
2798
2885
2799
- wir = param ( "wir" , input_kernel_shape , initializer: kernel_initializer )
2800
- wiz = param ( "wiz" , input_kernel_shape , initializer: kernel_initializer )
2801
- win = param ( "win" , input_kernel_shape , initializer: kernel_initializer )
2886
+ initializer = fn prefix , init ->
2887
+ fn shape , type , key ->
2888
+ split_key = Nx.Random . split ( key , parts: 3 )
2802
2889
2803
- whr = param ( "whr" , hidden_kernel_shape , initializer: kernel_initializer )
2804
- whz = param ( "whz" , hidden_kernel_shape , initializer: kernel_initializer )
2805
- whn = param ( "whn" , hidden_kernel_shape , initializer: kernel_initializer )
2890
+ init =
2891
+ if is_atom ( init ) do
2892
+ apply ( Axon.Initializers , init , [ ] )
2893
+ else
2894
+ init
2895
+ end
2806
2896
2807
- input_kernel = param ( "input_kernel" , { :map , [ wir , wiz , win ] } )
2808
- hidden_kernel = param ( "hidden_kernel" , { :map , [ whr , whz , whn ] } )
2897
+ fun =
2898
+ case init do
2899
+ init when is_function ( init , 2 ) ->
2900
+ fn _ -> init . ( shape , type ) end
2901
+
2902
+ init when is_function ( init , 3 ) ->
2903
+ fn key -> init . ( shape , type , key ) end
2904
+ end
2905
+
2906
+ % {
2907
+ "#{ prefix } r" => fun . ( split_key [ 0 ] ) ,
2908
+ "#{ prefix } z" => fun . ( split_key [ 1 ] ) ,
2909
+ "#{ prefix } n" => fun . ( split_key [ 2 ] )
2910
+ }
2911
+ end
2912
+ end
2913
+
2914
+ input_kernel =
2915
+ parameter ( "input_kernel" , input_kernel_template ,
2916
+ initializer: initializer . ( "wi" , opts [ :kernel_initializer ] )
2917
+ )
2918
+
2919
+ hidden_kernel =
2920
+ parameter ( "hidden_kernel" , hidden_kernel_template ,
2921
+ initializer: initializer . ( "wh" , opts [ :kernel_initializer ] )
2922
+ )
2809
2923
2810
2924
hidden_state_name =
2811
2925
case opts [ :name ] do
@@ -2822,14 +2936,34 @@ defmodule Axon do
2822
2936
2823
2937
inputs =
2824
2938
if opts [ :use_bias ] do
2825
- bias_initializer = opts [ :bias_initializer ]
2939
+ bias_initializer = fn shape , type , key ->
2940
+ split_key = Nx.Random . split ( key , parts: 4 )
2941
+
2942
+ init =
2943
+ if is_atom ( opts [ :bias_initializer ] ) do
2944
+ apply ( Axon.Initializers , opts [ :bias_initializer ] , [ ] )
2945
+ else
2946
+ opts [ :bias_initializer ]
2947
+ end
2826
2948
2827
- br = param ( "br" , bias_shape , initializer: bias_initializer )
2828
- bz = param ( "bz" , bias_shape , initializer: bias_initializer )
2829
- bin = param ( "bin" , bias_shape , initializer: bias_initializer )
2830
- bhn = param ( "bhn" , bias_shape , initializer: bias_initializer )
2949
+ fun =
2950
+ case init do
2951
+ init when is_function ( init , 2 ) ->
2952
+ fn _ -> init . ( shape , type ) end
2953
+
2954
+ init when is_function ( init , 3 ) ->
2955
+ fn key -> init . ( shape , type , key ) end
2956
+ end
2957
+
2958
+ % {
2959
+ "br" => fun . ( split_key [ 0 ] ) ,
2960
+ "bz" => fun . ( split_key [ 1 ] ) ,
2961
+ "bin" => fun . ( split_key [ 2 ] ) ,
2962
+ "bhn" => fun . ( split_key [ 3 ] )
2963
+ }
2964
+ end
2831
2965
2832
- bias = param ( "bias" , { :map , [ br , bz , bin , bhn ] } )
2966
+ bias = parameter ( "bias" , bias_template , initializer: bias_initializer )
2833
2967
2834
2968
[ x , hidden_state , opts [ :mask ] , input_kernel , hidden_kernel , bias ]
2835
2969
else
@@ -2983,23 +3117,26 @@ defmodule Axon do
2983
3117
unroll = opts [ :unroll ]
2984
3118
kernel_initializer = opts [ :kernel_initializer ]
2985
3119
2986
- hidden_kernel_shape = fn _ , { inp , _ } , _ ->
2987
- shape = Tuple . delete_at ( inp , 1 )
2988
- Axon.Shape . conv_kernel ( shape , 4 * units , kernel_size , :first , 1 )
3120
+ hidden_kernel_template = fn _ , { inp , _ } , _ ->
3121
+ shape = Tuple . delete_at ( Nx . shape ( inp ) , 1 )
3122
+ shape = Axon.Shape . conv_kernel ( shape , 4 * units , kernel_size , :first , 1 )
3123
+ Nx . template ( shape , :f32 )
2989
3124
end
2990
3125
2991
- input_kernel_shape = fn inp , _ , _ ->
2992
- shape = Tuple . delete_at ( inp , 1 )
2993
- Axon.Shape . conv_kernel ( shape , 4 * units , kernel_size , :first , 1 )
3126
+ input_kernel_template = fn inp , _ , _ ->
3127
+ shape = Tuple . delete_at ( Nx . shape ( inp ) , 1 )
3128
+ shape = Axon.Shape . conv_kernel ( shape , 4 * units , kernel_size , :first , 1 )
3129
+ Nx . template ( shape , :f32 )
2994
3130
end
2995
3131
2996
- bias_shape = fn inp , _ , _ ->
2997
- shape = Tuple . delete_at ( inp , 1 )
2998
- Axon.Shape . conv_bias ( shape , 4 * units , kernel_size , :first , 1 )
3132
+ bias_template = fn inp , _ , _ ->
3133
+ shape = Tuple . delete_at ( Nx . shape ( inp ) , 1 )
3134
+ shape = Axon.Shape . conv_bias ( shape , 4 * units , kernel_size , :first , 1 )
3135
+ Nx . template ( shape , :f32 )
2999
3136
end
3000
3137
3001
- wi = param ( "input_kernel" , input_kernel_shape , initializer: kernel_initializer )
3002
- wh = param ( "hidden_kernel" , hidden_kernel_shape , initializer: kernel_initializer )
3138
+ wi = parameter ( "input_kernel" , input_kernel_template , initializer: kernel_initializer )
3139
+ wh = parameter ( "hidden_kernel" , hidden_kernel_template , initializer: kernel_initializer )
3003
3140
3004
3141
hidden_state_name =
3005
3142
case opts [ :name ] do
@@ -3017,7 +3154,7 @@ defmodule Axon do
3017
3154
{ inputs , op } =
3018
3155
if opts [ :use_bias ] do
3019
3156
bias_initializer = opts [ :bias_initializer ]
3020
- b = param ( "bias" , bias_shape , initializer: bias_initializer )
3157
+ b = parameter ( "bias" , bias_template , initializer: bias_initializer )
3021
3158
{ [ x , hidden_state , opts [ :mask ] , wi , wh , b ] , :conv_lstm }
3022
3159
else
3023
3160
{ [ x , hidden_state , opts [ :mask ] , wi , wh ] , :conv_lstm }
0 commit comments