Skip to content

Commit 51e1893

Browse files
committed
metric learning PR updates
1 parent 1ceae68 commit 51e1893

File tree

1 file changed

+27
-42
lines changed

1 file changed

+27
-42
lines changed

notebooks/vision/metric-learning.livemd

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Mix.install([
1212

1313
Nx.global_default_backend(EXLA.Backend)
1414
Nx.Defn.global_default_options(compiler: EXLA)
15+
1516
```
1617

1718
## Dataset
@@ -54,7 +55,6 @@ In metric learning, we don’t hand the model lone examples; instead, we show it
5455
class_idx_to_train_idxs =
5556
bin
5657
|> Nx.from_binary(type)
57-
|> Nx.reshape(shape)
5858
|> Nx.to_flat_list()
5959
|> Enum.with_index()
6060
|> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)
@@ -64,7 +64,6 @@ class_idx_to_train_idxs =
6464
class_idx_to_test_idxs =
6565
bin
6666
|> Nx.from_binary(type)
67-
|> Nx.reshape(shape)
6867
|> Nx.to_flat_list()
6968
|> Enum.with_index()
7069
|> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)
@@ -80,22 +79,15 @@ With the index in place, the training loop draws one anchor and one sibling set
8079
```elixir
8180
defmodule GetImages do
8281
def batch(train_images, class_idx_to_train_idxs) do
83-
anchors_idx = Enum.map(0..9, fn class ->
84-
indices = class_idx_to_train_idxs[class]
85-
Enum.random(indices)
86-
end)
82+
{anchors_idx, positives_idx} =
83+
Enum.unzip(for class <- 0..9 do
84+
[a, p] = Enum.take_random(class_idx_to_train_idxs[class], 2)
85+
{a, p}
86+
end)
8787

88-
positives_idx = Enum.map(0..9, fn class ->
89-
indices = class_idx_to_train_idxs[class]
90-
# Exclude the anchor from possible positives
91-
anchor_idx = Enum.at(anchors_idx, class)
92-
indices
93-
|> Enum.filter(fn idx -> idx != anchor_idx end)
94-
|> Enum.random()
95-
end)
88+
anchors = Nx.take(train_images, Nx.tensor(anchors_idx)) |> Nx.rename(nil)
89+
positives = Nx.take(train_images, Nx.tensor(positives_idx)) |> Nx.rename(nil)
9690

97-
anchors = Nx.take(train_images, Nx.tensor(anchors_idx)) |> Nx.reshape({10, 32, 32, 3})
98-
positives = Nx.take(train_images, Nx.tensor(positives_idx)) |> Nx.reshape({10, 32, 32, 3})
9991
{anchors, positives}
10092
end
10193
end
@@ -155,13 +147,9 @@ defmodule MetricModel do
155147
end
156148

157149
defn normalize(x) do
158-
den =
159-
Nx.multiply(x, x)
160-
|> Nx.sum(axes: [-1], keep_axes: true)
161-
|> Nx.sqrt()
162-
den = Nx.max(den, 1.0e-7)
163-
Nx.divide(x, den)
164-
end
150+
norm = Nx.LinAlg.norm(x, axes: [-1], keep_axes: true)
151+
Nx.divide(x, norm)
152+
end
165153

166154
end
167155
```
@@ -211,7 +199,7 @@ The training loop then uses that loss to nudge parameters, pulling same-class ve
211199
defmodule MetricLearning do
212200
import Nx.Defn
213201
require Logger
214-
202+
215203
defn objective_fn(predict_fn, params, {anchor, positive}) do
216204
%{prediction: anchor_embeddings} = predict_fn.(params, %{"input" => anchor})
217205
%{prediction: positive_embeddings} = predict_fn.(params, %{"input" => positive})
@@ -304,43 +292,40 @@ near_neighbors_per_example = 10
304292

305293
embeddings = Nx.rename(embeddings, [nil, nil])
306294
gram_matrix = Nx.dot(embeddings, Nx.transpose(embeddings))
295+
307296
{_vals, neighbors} = Nx.top_k(gram_matrix, k: near_neighbors_per_example + 1)
297+
308298
:ok
309299
```
310300

311-
To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we randomly pick one example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors so you can see which images the network considers its nearest matches.
301+
To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we pick its first example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors to see which images the network considers its nearest matches.
312302

313303
```elixir
314304
# take first image of each class
315305
example_per_class_idx =
316306
0..9
317307
|> Enum.map(fn class_idx ->
318-
class_idx_to_test_idxs[class_idx] |> Enum.random()
308+
class_idx_to_test_idxs[class_idx] |> Enum.at(0)
319309
end)
320310
|> Nx.tensor(type: {:s, 64})
321311

322312
# take nearest neighbors for each example
323313
neighbors_for_samples = Nx.take(neighbors, example_per_class_idx, axis: 0)
324314

325-
# show the ten closest images
326-
images = for row_idx <- 0..9 do
327-
neighbour_idxs =
328-
neighbors_for_samples
329-
|> Nx.slice([row_idx, 0], [1, near_neighbors_per_example])
330-
|> Nx.squeeze()
315+
neighbour_idxs =
316+
neighbors_for_samples
331317
|> Nx.to_flat_list()
332318

333-
images =
334-
for idx <- neighbour_idxs do
335-
test_images
336-
|> Nx.take(Nx.tensor([idx]), axis: 0)
337-
|> Nx.squeeze()
338-
|> Nx.transpose(axes: [:width, :height, :channels])
339-
|> create_kino_image.()
340-
end
319+
images =
320+
for idx <- neighbour_idxs do
321+
test_images[idx]
322+
|> Nx.squeeze()
323+
|> Nx.transpose(axes: [:width, :height, :channels])
324+
|> create_kino_image.()
325+
end
326+
327+
Kino.render(Kino.Layout.grid(images, columns: 11))
341328

342-
Kino.render(Kino.Layout.grid(images, columns: near_neighbors_per_example))
343-
end
344329
:ok
345330
```
346331

0 commit comments

Comments
 (0)