Skip to content

Commit ce2e247

Browse files
authored
Fix: call apply/3 as intended (#598)
* Fix: call apply/3 as intended * Add tests for Axon.Quantizaiton.weight_only_quantized_dense
1 parent 4cc474b commit ce2e247

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

lib/axon/quantization.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ defmodule Axon.Quantization do
132132
fun =
133133
case opts[:kernel_initializer] do
134134
init when is_atom(init) ->
135-
apply(Axon.Initializers, [])
135+
apply(Axon.Initializers, init, [])
136136

137137
fun when is_function(fun) ->
138138
fun

test/axon/quantization_test.exs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,18 @@ defmodule Axon.QuantizationTest do
4242
assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp))
4343
end
4444
end
45+
46+
describe "weight_only_quantized_dense" do
47+
test "inits and executes properly" do
48+
model =
49+
Axon.input("input")
50+
|> Axon.Quantization.weight_only_quantized_dense(10)
51+
52+
assert {init_fn, _} = Axon.build(model)
53+
assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty())
54+
55+
assert {_, predict_fn} = Axon.build(model)
56+
assert predict_fn.(model_state, Nx.broadcast(1.0, {1, 1}))
57+
end
58+
end
4559
end

0 commit comments

Comments
 (0)