Skip to content

Commit 7c50a51

Browse files
authored
fix: Nx.LinAlg.norm axes support (#1522)
1 parent 7a3d7cd commit 7c50a51

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

nx/lib/nx/lin_alg.ex

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,10 @@ defmodule Nx.LinAlg do
388388
# The idea is that by dividing the tensor by it, large values of
389389
# tensor entries and large values of p are reduced, which in turn
390390
# avoids numerical overflow.
391+
392+
keep_axes = opts[:keep_axes]
393+
394+
opts = Keyword.put(opts, :keep_axes, true)
391395
numerical_stability_coefficient = Nx.reduce_max(abs_t, opts)
392396

393397
# This code prevents from division by zero.
@@ -398,12 +402,19 @@ defmodule Nx.LinAlg do
398402
1
399403
)
400404

401-
abs_t
402-
|> Nx.divide(numerical_stability_coefficient)
403-
|> Nx.pow(ord)
404-
|> Nx.sum(opts)
405-
|> Nx.pow(inv_ord)
406-
|> Nx.multiply(numerical_stability_coefficient)
405+
result =
406+
abs_t
407+
|> Nx.divide(numerical_stability_coefficient)
408+
|> Nx.pow(ord)
409+
|> Nx.sum(opts)
410+
|> Nx.pow(inv_ord)
411+
|> Nx.multiply(numerical_stability_coefficient)
412+
413+
if keep_axes do
414+
result
415+
else
416+
Nx.squeeze(result, Keyword.take(opts, [:axes]))
417+
end
407418
end
408419

409420
@doc """

nx/test/nx/lin_alg_test.exs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,19 @@ defmodule Nx.LinAlgTest do
276276
Nx.LinAlg.norm(t, ord: -3)
277277
end)
278278
end
279+
280+
test "correctly support axes option" do
281+
t =
282+
Nx.tensor([
283+
[-1.0, -1.0],
284+
[0.0, 0.0],
285+
[1.0, 1.0]
286+
])
287+
288+
result = Nx.tensor([1.4142135381698608, 0.0, 1.4142135381698608])
289+
assert Nx.LinAlg.norm(t, axes: [1]) == result
290+
assert Nx.LinAlg.norm(t, axes: [1], keep_axes: true) == Nx.reshape(result, {3, 1})
291+
end
279292
end
280293

281294
describe "matrix_power" do

0 commit comments

Comments
 (0)