From b0f67c7d7447aac923f018b3a1533e9ec6a1a2f0 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Wed, 12 Jun 2024 09:52:30 -0400
Subject: [PATCH] Start weight tying

---
 lib/axon/compiler.ex                     | 57 ++++++++++++++++++------
 lib/axon/model_state.ex                  | 34 ++++++++++++++
 lib/axon/model_state/shared_parameter.ex | 19 ++++
 test/axon/compiler_test.exs              | 55 ++++++++++++++++++++++-
 4 files changed, 149 insertions(+), 16 deletions(-)
 create mode 100644 lib/axon/model_state/shared_parameter.ex

diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index ebd365436..dad4764b4 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -316,6 +316,10 @@ defmodule Axon.Compiler do
     end)
   end
 
+  defp merge_type(_, %Axon.ModelState.SharedParameter{}, value), do: value
+
+  defp merge_type(_, _, %Axon.ModelState.SharedParameter{} = shared), do: shared
+
   defp merge_type(key, template, value) do
     if Nx.type(template) != Nx.type(value) do
       Logger.warning(
@@ -1061,20 +1065,7 @@ defmodule Axon.Compiler do
     # freezing and dtype policy
     parameter_inputs =
       Enum.map(layer_params, fn %{name: v, frozen: frz} ->
-        param = params[name][v]
-
-        cond do
-          param != nil ->
-            safe_policy_cast(maybe_freeze(param, frz), policy, :compute)
-
-          true ->
-            raise ArgumentError,
-                  "parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
-                    " was not present in the given parameter map, this can" <>
-                    " happen if you are using parameters intended for another" <>
-                    " model or did not initialize portions of your model with" <>
-                    " Axon.init/3"
-        end
+        resolve_parameter!(params, name, v, frz, policy)
       end)
 
     # Reorder the inputs according to the original input ordering
@@ -1291,6 +1282,44 @@ defmodule Axon.Compiler do
     initializer.(shape, type, keys[layer_id][name])
   end
 
+  defp resolve_parameter!(params, layer_name, param_name, freeze?, policy) do
+    layer_params =
+      case params[layer_name] do
+        nil ->
+          raise ArgumentError, "layer #{inspect(layer_name)} does not exist in the model state"
+
+        %Axon.ModelState.SharedParameter{path: path} ->
+          get_in(params, path)
+
+        map ->
+          map
+      end
+
+    parameter =
+      case layer_params[param_name] do
+        nil ->
+          raise ArgumentError,
+                "parameter #{inspect(param_name)} for layer: #{inspect(layer_name)}" <>
+                  " was not present in the given parameter map, this can" <>
+                  " happen if you are using parameters intended for another" <>
+                  " model or did not initialize portions of your model with" <>
+                  " Axon.init/3"
+
+        %Axon.ModelState.SharedParameter{path: path} ->
+          with nil <- get_in(params, path) do
+            raise ArgumentError,
+                  "shared parameter for #{inspect(param_name)} in layer:" <>
+                    " #{inspect(layer_name)} references non-existent parameter" <>
+                    " #{inspect(path)}"
+          end
+
+        parameter ->
+          parameter
+      end
+
+    safe_policy_cast(maybe_freeze(parameter, freeze?), policy, :compute)
+  end
+
   defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param)
   defp maybe_freeze(param, false), do: param
 
diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex
index 8eede9a6f..494f4e66e 100644
--- a/lib/axon/model_state.ex
+++ b/lib/axon/model_state.ex
@@ -196,6 +196,33 @@ defmodule Axon.ModelState do
     }
   end
 
+  @doc """
+  Ties parameters in the model state together.
+
+  Tied parameters are specified as a destination parameter path
+  and a source parameter path. For example, if you want the kernel
+  of an embedding layer to use the kernel of a dense layer as its
+  source, you would do:
+
+      Axon.ModelState.tie(model_state, ["embedding", "kernel"], ["dense", "kernel"])
+
+  You can tie individual parameters or entire layers together:
+
+      Axon.ModelState.tie(model_state, ["embedding"], ["dense"])
+  """
+  def tie(model_state, destination, source) do
+    update_in(model_state, [Access.key!(:data)], fn data ->
+      shared = Axon.ModelState.SharedParameter.new(source)
+      [key | rest] = Enum.reverse(destination)
+
+      shared = Enum.reduce(rest, %{key => shared}, fn next, acc ->
+        %{next => acc}
+      end)
+
+      tree_merge(shared, data, fn _, lhs, _ -> lhs end)
+    end)
+  end
+
   # Helpers
 
   defp get_paths(map) do
@@ -269,6 +296,10 @@ defmodule Axon.ModelState do
         nil ->
           Map.put(acc, key, val_lhs)
 
+        %Axon.ModelState.SharedParameter{} = val_rhs ->
+          new_val = fun.(key, val_lhs, val_rhs)
+          Map.put(acc, key, new_val)
+
         %Nx.Tensor{} = val_rhs ->
           new_val = fun.(key, val_lhs, val_rhs)
           Map.put(acc, key, new_val)
@@ -321,6 +352,9 @@ defmodule Axon.ModelState do
       {_, %Nx.Tensor{} = tensor}, {count, size} ->
         {count + Nx.size(tensor), size + Nx.byte_size(tensor)}
 
+      {_, %Axon.ModelState.SharedParameter{}}, {count, size} ->
+        {count, size}
+
       {_, map}, {count, size} ->
         {inner_count, inner_size} = get_param_info(map)
         {count + inner_count, size + inner_size}
diff --git a/lib/axon/model_state/shared_parameter.ex b/lib/axon/model_state/shared_parameter.ex
new file mode 100644
index 000000000..6d505fc68
--- /dev/null
+++ b/lib/axon/model_state/shared_parameter.ex
@@ -0,0 +1,19 @@
+defmodule Axon.ModelState.SharedParameter do
+  @moduledoc false
+
+  # Represents a tied or shared parameter for layers whose
+  # weights are connected but don't necessarily perform the
+  # same operation. This implements the Nx.Container protocol
+  # and contains an access path to the parameter that holds the
+  # original weight.
+
+  @derive {
+    Nx.Container,
+    keep: [:path], containers: []
+  }
+  defstruct [:path]
+
+  def new(path) do
+    %__MODULE__{path: path}
+  end
+end
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index 4b6026e84..5826fab78 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5465,15 +5465,24 @@ defmodule CompilerTest do
   end
 
   describe "edge cases" do
-    test "raises clean error on missing parameter" do
+    test "raises clean error on missing layer" do
       model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
       input = Nx.tensor([[1.0]])
 
-      assert_raise ArgumentError, ~r/parameter "kernel" for layer:/, fn ->
+      assert_raise ArgumentError, ~r/layer \"dense_0\" does not exist/, fn ->
         Axon.predict(model, ModelState.empty(), input)
       end
     end
 
+    test "raises clean error on missing parameter" do
+      model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
+      input = Nx.tensor([[1.0]])
+
+      assert_raise ArgumentError, ~r/parameter \"kernel\" for layer:/, fn ->
+        Axon.predict(model, ModelState.new(%{"dense_0" => %{}}), input)
+      end
+    end
+
     test "initializes a non-linear model" do
       x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2, name: "dense_0")
       y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2, name: "dense_1")
@@ -5739,4 +5748,46 @@ defmodule CompilerTest do
       assert out =~ "bar:"
     end
   end
+
+  describe "weight tying" do
+    test "initializes with shared parameters" do
+      model =
+        Axon.input("x")
+        |> Axon.embedding(32, 32, name: "embed")
+        |> Axon.dense(32, name: "dense")
+
+      init_state =
+        ModelState.empty()
+        |> ModelState.tie(["embed", "kernel"], ["dense", "kernel"])
+
+      {init_fn, _} = Axon.build(model)
+      input = Nx.template({1, 4}, :u32)
+      assert %Axon.ModelState{data: %{"embed" => %{"kernel" => %Axon.ModelState.SharedParameter{}}}} = init_fn.(input, init_state)
+    end
+
+    test "performs inference with weights tied after initialization" do
+      model =
+        Axon.input("x")
+        |> Axon.embedding(32, 32, name: "embed")
+        |> Axon.dense(32, name: "dense")
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      %Axon.ModelState{data: %{"dense" => %{"kernel" => k, "bias" => b}}} =
+        model_state = init_fn.(Nx.template({1, 4}, :u32), ModelState.empty())
+
+      model_state =
+        Axon.ModelState.tie(model_state, ["embed", "kernel"], ["dense", "kernel"])
+
+      input = Nx.tensor([[0, 1, 2, 3]])
+
+      actual_predict_fn = fn input, kernel, bias ->
+        input
+        |> Axon.Layers.embedding(kernel)
+        |> Axon.Layers.dense(kernel, bias)
+      end
+
+      assert_equal(actual_predict_fn.(input, k, b), predict_fn.(model_state, input))
+    end
+  end
 end
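
Taken together, the new tests show the intended workflow for this API. Below is a rough end-to-end sketch of tying an embedding kernel to a dense kernel; the "x", "embed", and "dense" names and the 32x32 shapes mirror the tests above and are illustrative rather than anything required by the patch:

    model =
      Axon.input("x")
      |> Axon.embedding(32, 32, name: "embed")
      |> Axon.dense(32, name: "dense")

    {init_fn, predict_fn} = Axon.build(model)

    # Initialize as usual, then point the embedding kernel at the dense kernel.
    # After tying, ["embed", "kernel"] holds an Axon.ModelState.SharedParameter
    # whose path is resolved by resolve_parameter!/5 when the model executes.
    model_state = init_fn.(Nx.template({1, 4}, :u32), Axon.ModelState.empty())
    model_state = Axon.ModelState.tie(model_state, ["embed", "kernel"], ["dense", "kernel"])

    predict_fn.(model_state, Nx.tensor([[0, 1, 2, 3]]))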