
Commit 8e0a6d9

Generalize quantization layers
1 parent a54ee13 commit 8e0a6d9

4 files changed: +41 -16 lines changed

lib/axon.ex

Lines changed: 1 addition & 1 deletion
@@ -3974,7 +3974,7 @@ defmodule Axon do
 
   """
   @doc type: :debug
-  def trace_init(model, template, params \\ %{}, opts \\ []) do
+  def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do
     {init_fn, _} = build(model, opts)
     Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params)
   end
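
For context, a minimal usage sketch (not part of the commit; the model, input name, and shapes below are hypothetical) showing trace_init with the new default, where params falls back to Axon.ModelState.empty() instead of a bare %{}:

# Hypothetical two-layer model and input template.
model = Axon.input("data", shape: {nil, 8}) |> Axon.dense(4)
template = Nx.template({1, 8}, :f32)

# Relies on the new default: params \\ Axon.ModelState.empty().
traced = Axon.trace_init(model, template)

# Passing an explicit model state still works the same way.
traced = Axon.trace_init(model, template, Axon.ModelState.empty())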

lib/axon/model_state.ex

Lines changed: 30 additions & 7 deletions
@@ -171,12 +171,7 @@ defmodule Axon.ModelState do
   Returns an empty model state.
   """
   def empty() do
-    %Axon.ModelState{
-      data: %{},
-      parameters: %{},
-      state: %{},
-      frozen_parameters: %{}
-    }
+    new(%{})
   end
 
   @doc """
@@ -190,12 +185,40 @@
   def new(data) when is_map(data) do
     %Axon.ModelState{
       data: data,
-      parameters: get_paths(data),
+      parameters: transform_to_parameters(data),
       state: %{},
       frozen_parameters: %{}
     }
   end
 
+  defp transform_to_parameters(%Nx.Tensor{}), do: nil
+
+  defp transform_to_parameters(map) when is_map(map) do
+    map
+    |> Enum.map(fn {k, v} -> {k, transform_to_parameters(v)} end)
+    |> Enum.into(%{})
+  end
+
+  defp transform_to_parameters(list) when is_list(list) do
+    Enum.map(list, &transform_to_parameters/1)
+  end
+
+  defp transform_to_parameters(value) do
+    case value do
+      map when is_map(map) ->
+        keys = Map.keys(map)
+
+        if Enum.all?(keys, &(is_map(map[&1]) or match?(%Nx.Tensor{}, map[&1]))) do
+          keys
+        else
+          transform_to_parameters(map)
+        end
+
+      _ ->
+        value
+    end
+  end
+
   # Helpers
 
   defp get_paths(map) do
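
As a usage sketch (not from the commit; the layer and parameter names are hypothetical): empty/0 is now just new/1 applied to an empty map, and new/1 derives the :parameters field by walking the data recursively with transform_to_parameters/1 rather than calling get_paths/1.

# empty/0 now delegates to new/1, so these two states are equal.
state = Axon.ModelState.empty()
^state = Axon.ModelState.new(%{})

# new/1 still takes a (possibly nested) map of tensors as the model data;
# only the way the :parameters field is derived has changed.
data = %{
  "dense_0" => %{
    "kernel" => Nx.iota({8, 4}, type: :f32),
    "bias" => Nx.broadcast(0.0, {4})
  }
}

%Axon.ModelState{data: ^data} = Axon.ModelState.new(data)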

lib/axon/quantization/layers.ex

Lines changed: 9 additions & 7 deletions
@@ -35,18 +35,20 @@ defmodule Axon.Quantization.Layers do
         bias,
         _opts
       ) do
-    x_shape = Nx.shape(x)
-    last_dim = Nx.axis_size(x, -1)
+    x_view = Nx.reshape(x, {:auto, Nx.axis_size(x, -1)})
 
-    x_view = Nx.reshape(x, {:auto, last_dim})
-
-    y = Nx.dot(x_view, Nx.as_type(Nx.transpose(w_int8), Nx.type(x)))
-    y = Nx.multiply(y, scales)
-    y = reshape_output(y, x_shape)
+    y = Nx.dot(x_view, Nx.as_type(w_int8, Nx.type(x)))
+    y = Nx.multiply(y, reshape_scales(scales, y))
+    y = reshape_output(y, Nx.shape(x))
 
     Nx.add(y, bias)
   end
 
+  deftransformp reshape_scales(scales, y) do
+    ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
+    Nx.reshape(scales, Tuple.append(ones, :auto))
+  end
+
   deftransformp reshape_output(output, x_shape) do
     all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
     new_shape = Tuple.append(all_but_last, :auto)
lib/axon/quantization/q_tensor.ex

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ defmodule Axon.Quantization.QTensor do
        max: opts[:max]
      )
 
-    struct(__MODULE__, value: quantized_value, scale: scale, zero_point: zero_point)
+    struct(__MODULE__, value: Nx.transpose(quantized_value), scale: scale, zero_point: zero_point)
   end
 
   deftransformp quantize_affine(
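
For context, a small sketch (hypothetical shapes; the kernel orientation is illustrative only) of why transposing the quantized kernel once at construction time is equivalent to transposing it inside the dense layer on every call, which is what the matching change in layers.ex removes:

w = Nx.iota({5, 4}, type: :f32)    # stand-in for the quantized kernel before the transpose
x = Nx.iota({3, 4}, type: :f32)    # hypothetical input

y_per_call = Nx.dot(x, Nx.transpose(w))   # old: transpose inside the layer on every call
w_stored = Nx.transpose(w)                # new: transpose once when the QTensor is built
y_stored = Nx.dot(x, w_stored)            # new layer dots against the stored value directly

1 = Nx.equal(y_per_call, y_stored) |> Nx.all() |> Nx.to_number()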
