Commit ee8f855

Add simple quantization API (#586)
* Quantization draft
* Finish initial quantization API
* Docs
1 parent 216fafe commit ee8f855

File tree

7 files changed: +528 −8 lines changed


lib/axon.ex

Lines changed: 35 additions & 2 deletions
@@ -855,6 +855,12 @@ defmodule Axon do
         use_bias: true
       ])
 
+    meta =
+      opts[:meta] ||
+        %{}
+        |> Map.put(:units, units)
+        |> Map.put(:use_bias, opts[:use_bias])
+
     kernel_shape = &Axon.Shape.dense_kernel(&1, units)
     bias_shape = &Axon.Shape.dense_bias(&1, units)
 
@@ -868,7 +874,7 @@ defmodule Axon do
         {[x, kernel], :dense}
       end
 
-    node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense)
+    node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
 
     if activation = opts[:activation] do
       activation(node, activation)
@@ -3666,7 +3672,7 @@ defmodule Axon do
   """
   @doc type: :graph
   def get_op_counts(%Axon{} = axon) do
-    reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts ->
+    reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts ->
       Map.update(op_counts, op, 1, fn x -> x + 1 end)
     end)
   end
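
Keying the counts on `op_name` instead of `op` means layers report under their
semantic name rather than their underlying implementation function. A quick
sketch of the expected behavior after this change (illustrative model):

    model =
      Axon.input("x", shape: {nil, 8})
      |> Axon.dense(16)
      |> Axon.dense(4)

    Axon.get_op_counts(model)
    #=> %{dense: 2, input: 1}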
@@ -4096,6 +4102,33 @@ defmodule Axon do
     end
   end
 
+  @doc """
+  Returns a mapping of layer names to layer properties.
+  """
+  def properties(%Axon{output: id, nodes: nodes}) do
+    {_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}})
+    properties
+  end
+
+  defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do
+    case cache do
+      %{^id => _} ->
+        {cache, op_counts, properties}
+
+      %{} ->
+        %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]
+
+        {cache, op_counts, properties} =
+          Enum.reduce(parents, acc, &node_properties(&1, nodes, &2))
+
+        name = name_fn.(op_name, op_counts)
+        op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
+        properties = Map.put(properties, name, op_name)
+
+        {Map.put(cache, id, name), op_counts, properties}
+    end
+  end
+
   ## Helpers
 
   @valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
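
The `node_properties/3` walk tracks op counts the same way Axon generates
default layer names, so the returned map keys line up with the names used in a
model state. A minimal sketch (layer names given explicitly here to keep the
output predictable):

    model =
      Axon.input("data", shape: {nil, 8})
      |> Axon.dense(16, name: "hidden")
      |> Axon.dense(1, name: "output")

    Axon.properties(model)
    #=> %{"data" => :input, "hidden" => :dense, "output" => :dense}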

lib/axon/model_state.ex

Lines changed: 11 additions & 0 deletions
@@ -206,6 +206,8 @@ defmodule Axon.ModelState do
 
   defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)]
 
+  defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)]
+
   defp traverse(map, acc) do
     Enum.flat_map(map, fn {k, value} ->
       traverse(value, [k | acc])
@@ -273,6 +275,10 @@ defmodule Axon.ModelState do
           new_val = fun.(key, val_lhs, val_rhs)
           Map.put(acc, key, new_val)
 
+        %Axon.Quantization.QTensor{} = val_rhs ->
+          new_val = fun.(key, val_lhs, val_rhs)
+          Map.put(acc, key, new_val)
+
         val_rhs when is_map(val_lhs) and is_map(val_rhs) ->
           updated_val = tree_merge(val_lhs, val_rhs, fun)
           Map.put(acc, key, updated_val)
@@ -321,6 +327,11 @@ defmodule Axon.ModelState do
       {_, %Nx.Tensor{} = tensor}, {count, size} ->
         {count + Nx.size(tensor), size + Nx.byte_size(tensor)}
 
+      {_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}},
+      {count, size} ->
+        {count + Nx.size(value) + Nx.size(scale) + Nx.size(zero),
+         size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)}
+
       {_, map}, {count, size} ->
        {inner_count, inner_size} = get_param_info(map)
        {count + inner_count, size + inner_size}
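
For sizing purposes a `QTensor` contributes all three of its component
tensors. Illustrative arithmetic, assuming a hypothetical `{1024, 1024}`
kernel of type `{:s, 8}` with a scalar `{:f, 32}` scale and a scalar
`{:s, 8}` zero point:

    # count = 1024 * 1024 + 1 + 1        #=> 1_048_578 parameters
    # bytes = 1024 * 1024 * 1 + 4 + 1    #=> 1_048_581 bytes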

lib/axon/quantization.ex

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+defmodule Axon.Quantization do
+  @moduledoc """
+  Model quantization.
+
+  Model quantization is a technique for reducing the memory footprint of
+  a model by converting portions of a model to use quantized representations.
+  Typically, these quantized representations are low-precision integers.
+
+  This is an **experimental** API which implements weight-only quantization.
+  The implementation in this module will convert dense layers in a large
+  model to quantized variants. The only supported quantization type is
+  `{:s, 8}`. Axon quantization is inference-only. Training is not currently
+  supported.
+  """
+  alias Axon.Quantization.Layers
+  alias Axon.Quantization.QTensor
+
+  @doc """
+  Quantizes a model and a model state.
+
+  Given a model and model state, this function will rewrite all
+  of the dense layers in the model to perform weight-only 8-bit
+  integer versions of the same operation. It will also replace values
+  for all dense kernels in the given model state with quantized
+  tensors.
+  """
+  def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
+    quantized_model = quantize_model(model)
+    quantized_model_state = quantize_model_state(model, model_state)
+    {quantized_model, quantized_model_state}
+  end
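
A minimal end-to-end sketch of the intended workflow (the model shape and the
use of `Axon.ModelState.empty/0` are illustrative assumptions, not part of
this diff):

    model =
      Axon.input("features", shape: {nil, 128})
      |> Axon.dense(256, activation: :relu)
      |> Axon.dense(10)

    {init_fn, _} = Axon.build(model)
    model_state = init_fn.(Nx.template({1, 128}, :f32), Axon.ModelState.empty())

    {q_model, q_state} = Axon.Quantization.quantize(model, model_state)

    {_, predict_fn} = Axon.build(q_model)
    predict_fn.(q_state, Nx.broadcast(0.5, {1, 128}))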
+
+  @doc """
+  Replaces standard operations with quantized variants.
+
+  The only supported conversion is to convert regular dense layers
+  to a weight-only 8-bit integer variant. Note that this only rewrites
+  the model definition. If you have a pre-trained model state that you
+  wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.
+
+  All `:dense` layers in the model are replaced with
+  `Axon.Quantization.weight_only_quantized_dense/3`.
+  """
+  def quantize_model(%Axon{} = model) do
+    quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
+      weight_only_quantized_dense(x, units, use_bias: use_bias)
+    end
+
+    Axon.rewrite_nodes(model, fn
+      %Axon.Node{op: :dense, meta: meta} ->
+        &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])
+
+      _ ->
+        :skip
+    end)
+  end
+
+  @doc """
+  Returns a quantized model state.
+
+  Given a model and a model state, this function will replace
+  all dense layer kernels with a quantized version of the weight.
+
+  Training is not currently supported, so all quantized layers are
+  automatically frozen.
+  """
+  def quantize_model_state(model, model_state) do
+    dense_layer_names =
+      model
+      |> Axon.properties()
+      |> Enum.filter(fn {_, v} -> v == :dense end)
+      |> Enum.map(fn {k, _} -> k end)
+      |> MapSet.new()
+
+    state =
+      Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
+        update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
+      end)
+
+    Axon.ModelState.freeze(state, fn [name | _] ->
+      MapSet.member?(dense_layer_names, name)
+    end)
+  end
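
After quantization, each dense layer's `"kernel"` entry in the state's `data`
map holds a `QTensor` rather than a plain `Nx.Tensor`. A small sketch,
assuming a model with a dense layer named `"hidden"` (an illustrative name):

    q_state = Axon.Quantization.quantize_model_state(model, model_state)
    %Axon.Quantization.QTensor{} = q_state.data["hidden"]["kernel"]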
+
+  ## Layers
+
+  @doc """
+  Adds a weight-only quantized dense layer to the network.
+
+  This is equivalent to a dense layer, but works on quantized
+  weights for reducing model memory footprint.
+
+  Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`.
+
+  ## Options
+
+    * `:name` - layer name.
+
+    * `:kernel_initializer` - initializer for `kernel` weights.
+      Defaults to `:glorot_uniform`.
+
+    * `:bias_initializer` - initializer for `bias` weights. Defaults
+      to `:zeros`.
+
+    * `:use_bias` - whether the layer should add bias to the output.
+      Defaults to `true`.
+  """
+  def weight_only_quantized_dense(x, units, opts \\ []) do
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        :meta,
+        use_bias: true,
+        kernel_initializer: :glorot_uniform,
+        bias_initializer: :zeros
+      ])
+
+    meta =
+      opts[:meta] ||
+        %{}
+        |> Map.put(:units, units)
+        |> Map.put(:use_bias, opts[:use_bias])
+
+    kernel_shape = &Axon.Shape.dense_kernel(&1, units)
+    bias_shape = &Axon.Shape.dense_bias(&1, units)
+
+    kernel =
+      Axon.param("kernel", kernel_shape,
+        initializer: fn shape, type, key ->
+          fun =
+            case opts[:kernel_initializer] do
+              init when is_atom(init) ->
+                apply(Axon.Initializers, init, [])
+
+              fun when is_function(fun) ->
+                fun
+            end
+
+          tensor =
+            case fun do
+              fun when is_function(fun, 2) ->
+                fun.(shape, type)
+
+              fun when is_function(fun, 3) ->
+                fun.(shape, type, key)
+            end
+
+          QTensor.from_tensor(tensor)
+        end
+      )
+
+    {inputs, op} =
+      if opts[:use_bias] do
+        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
+        {[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
+      else
+        {[x, kernel], &Layers.weight_only_quantized_dense/3}
+      end
+
+    Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
+  end
+end
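
Because the builder sets `op_name: :dense`, quantized layers still show up as
`:dense` in `Axon.get_op_counts/1` and `Axon.properties/1`. It can also be
used directly when hand-building a model; a small sketch:

    model =
      Axon.input("x", shape: {nil, 64})
      |> Axon.Quantization.weight_only_quantized_dense(32)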

lib/axon/quantization/layers.ex

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+defmodule Axon.Quantization.Layers do
+  @moduledoc """
+  Quantized Layer Implementations.
+  """
+  alias Axon.Quantization.QTensor
+  import Nx.Defn
+
+  @doc """
+  Weight-only quantized version of a dense layer.
+
+  It expects the input kernel to be an `Axon.Quantization.QTensor`.
+  """
+  deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do
+    {bias, opts} =
+      case bias do
+        %Nx.Tensor{} = bias ->
+          {bias, opts}
+
+        bias when is_number(bias) ->
+          {bias, opts}
+
+        opts when is_list(opts) ->
+          {Nx.tensor(0), opts}
+
+        other ->
+          raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}"
+      end
+
+    weight_only_quantized_dense_impl(input, kernel, bias, opts)
+  end
+
+  defnp weight_only_quantized_dense_impl(
+          input,
+          %QTensor{value: kernel, scale: scale},
+          bias,
+          _opts
+        ) do
+    input
+    |> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
+    |> Nx.multiply(scale)
+    |> Nx.add(bias)
+  end
+end
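
The kernel stays in `{:s, 8}` storage; it is upcast to the input type only for
the `Nx.dot/4` contraction, and the product is rescaled afterwards, i.e.
`y = (x · w_s8) * scale + bias`. A standalone sketch of the symmetric absmax
scheme this math implies (illustrative only; the actual
`QTensor.from_tensor/1` implementation is not part of this diff and may be
per-channel rather than per-tensor):

    w = Nx.tensor([[0.5, -1.2], [0.3, 0.9]])

    # Symmetric per-tensor absmax quantization to s8
    scale = Nx.divide(Nx.reduce_max(Nx.abs(w)), 127)
    q = w |> Nx.divide(scale) |> Nx.round() |> Nx.as_type(:s8)

    # Dequantize-on-the-fly matmul approximates the f32 result
    x = Nx.tensor([[1.0, 2.0]])
    x |> Nx.dot(Nx.as_type(q, :f32)) |> Nx.multiply(scale)
    #=> close to Nx.dot(x, w)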
