Skip to content

Commit da18bb3

Browse files
committed
Merge coefficient types on linear interpolation, closes #318
1 parent b815c59 commit da18bb3

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

lib/scholar/interpolation/linear.ex

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,15 @@ defmodule Scholar.Interpolation.Linear do
131131

132132
defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do
133133
shape = Nx.shape(target_x)
134-
135134
target_x = Nx.flatten(target_x)
136-
137135
indices = Nx.argsort(target_x)
138136

139137
left_bound = x[0]
140138
right_bound = x[-1]
141139

142140
target_x = Nx.sort(target_x)
143-
res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)})
141+
type = Nx.Type.merge(to_float_type(target_x), coefficients.type)
142+
res = Nx.broadcast(Nx.tensor(0, type: type), {Nx.axis_size(target_x, 0)})
144143

145144
# while with smaller than left_bound
146145
{{res, i}, _} =

test/scholar/interpolation/linear_test.exs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,16 @@ defmodule Scholar.Interpolation.LinearTest do
5151
assert Linear.predict(model, Nx.tensor([[[-0.5], [0.5], [1.5], [2.5], [3.5]]])) ==
5252
Nx.tensor([[[0.0], [0.0], [0.5], [3], [7]]])
5353
end
54+
55+
test "with different types" do
56+
x_s = Nx.tensor([1, 2, 3], type: :u64)
57+
y_s = Nx.tensor([1.0, 2.0, 3.0], type: :f64)
58+
target = Nx.tensor([1, 2], type: :u64)
59+
60+
assert x_s
61+
|> Scholar.Interpolation.Linear.fit(y_s)
62+
|> Scholar.Interpolation.Linear.predict(target) ==
63+
Nx.tensor([1.0, 2.0], type: :f64)
64+
end
5465
end
5566
end

0 commit comments

Comments
 (0)