|
defmodule Axon.Quantization do
  @moduledoc """
  Model quantization.

  Model quantization is a technique for reducing the memory footprint of
  a model by converting portions of a model to use quantized representations.
  Typically, these quantized representations are low-precision integers.

  This is an **experimental** API which implements weight-only quantization.
  The implementation in this module will convert dense layers in a large
  model to quantized-variants. The only supported quantization type is
  `{:s, 8}`. Axon quantization is inference-only. Training is not currently
  supported.
  """
  alias Axon.Quantization.Layers
  alias Axon.Quantization.QTensor

  @doc """
  Quantizes a model and a model state.

  Given a model and model state, this method will rewrite all
  of the dense layers in the model to perform weight-only 8-bit
  integer versions of the same operation. It will also replace values
  for all dense kernels in the given model state with quantized
  tensors.

  Returns `{quantized_model, quantized_model_state}`.
  """
  def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
    quantized_model = quantize_model(model)
    # Note: the *original* model is used to locate dense layers in the
    # state; the rewritten model no longer reports them as `:dense`
    # properties in the same way.
    quantized_model_state = quantize_model_state(model, model_state)
    {quantized_model, quantized_model_state}
  end

  @doc """
  Replaces standard operations with quantized variants.

  The only supported conversion is to convert regular dense layers
  to a weight-only 8-bit integer variant. Note that this only replaces
  the properties of the model. If you have a pre-trained model state
  that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.

  All `:dense` layers in the model are replaced with
  `Axon.Quantization.weight_only_quantized_dense/3`.
  """
  def quantize_model(%Axon{} = model) do
    # Rebuilds a dense node as a weight-only quantized dense layer,
    # preserving its input, unit count, and bias configuration.
    quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
      weight_only_quantized_dense(x, units, use_bias: use_bias)
    end

    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :dense, meta: meta} ->
        # NOTE(review): assumes every dense node carries :units and
        # :use_bias in its meta (as set by Axon.dense and by
        # weight_only_quantized_dense/3 below) — confirm for custom layers.
        &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])

      _ ->
        :skip
    end)
  end

  @doc """
  Returns a quantized model state.

  Given a model and a model state, this function will replace
  all dense layer kernels with a quantized version of the weight.

  Training is not currently supported, so all quantized layers are
  automatically frozen.
  """
  def quantize_model_state(model, model_state) do
    # Collect the names of all layers whose op property is :dense so we
    # know which kernels to quantize and which parameters to freeze.
    dense_layer_names =
      model
      |> Axon.properties()
      |> Enum.filter(fn {_, v} -> v == :dense end)
      |> Enum.map(fn {k, _} -> k end)
      |> MapSet.new()

    state =
      Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
        update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
      end)

    # Quantized layers are inference-only: freeze every parameter that
    # belongs to a dense layer so optimizers skip them.
    Axon.ModelState.freeze(state, fn [name | _] ->
      MapSet.member?(dense_layer_names, name)
    end)
  end

  ## Layers

  @doc """
  Adds a weight-only quantized dense layer to the network.

  This is equivalent to a dense layer, but works on quantized
  weights for reducing model memory footprint.

  Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`.

  ## Options

    * `:name` - layer name.

    * `:kernel_initializer` - initializer for `kernel` weights.
      Defaults to `:glorot_uniform`.

    * `:bias_initializer` - initializer for `bias` weights. Defaults
      to `:zeros`.

    * `:use_bias` - whether the layer should add bias to the output.
      Defaults to `true`.
  """
  def weight_only_quantized_dense(x, units, opts \\ []) do
    opts =
      Keyword.validate!(opts, [
        :name,
        :meta,
        use_bias: true,
        kernel_initializer: :glorot_uniform,
        bias_initializer: :zeros
      ])

    # `|>` binds tighter than `||`, so the pipeline only runs when no
    # :meta was given. NOTE(review): a caller-supplied :meta is used
    # as-is, without :units/:use_bias — quantize_model/1 relies on those
    # keys being present; confirm callers passing :meta include them.
    meta =
      opts[:meta] ||
        %{}
        |> Map.put(:units, units)
        |> Map.put(:use_bias, opts[:use_bias])

    kernel_shape = &Axon.Shape.dense_kernel(&1, units)
    bias_shape = &Axon.Shape.dense_bias(&1, units)

    kernel =
      Axon.param("kernel", kernel_shape,
        initializer: fn shape, type, key ->
          # Resolve the configured initializer to a function: atoms name
          # zero-arity constructors in Axon.Initializers that return an
          # init fun; functions are used directly.
          fun =
            case opts[:kernel_initializer] do
              init when is_atom(init) ->
                # Fix: previously `apply(Axon.Initializers, [])`, which
                # dropped the function name and raised for every atom
                # initializer (including the default).
                apply(Axon.Initializers, init, [])

              fun when is_function(fun) ->
                fun
            end

          # Initializer funs come in two arities: (shape, type) and
          # (shape, type, key) for PRNG-seeded initializers.
          tensor =
            case fun do
              fun when is_function(fun, 2) ->
                fun.(shape, type)

              fun when is_function(fun, 3) ->
                fun.(shape, type, key)
            end

          # Quantize the freshly initialized kernel immediately so the
          # layer only ever stores the quantized representation.
          QTensor.from_tensor(tensor)
        end
      )

    {inputs, op} =
      if opts[:use_bias] do
        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
        {[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
      else
        {[x, kernel], &Layers.weight_only_quantized_dense/3}
      end

    # op_name stays :dense so tooling (and quantize_model_state/2's
    # property scan) treats this as a dense layer.
    Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
  end
end
0 commit comments