@@ -12,6 +12,7 @@ Mix.install([
12
12
13
13
Nx .global_default_backend (EXLA .Backend )
14
14
Nx .Defn .global_default_options (compiler: EXLA )
15
+
15
16
```
16
17
17
18
## Dataset
@@ -54,7 +55,6 @@ In metric learning, we don’t hand the model lone examples; instead, we show it
54
55
class_idx_to_train_idxs =
55
56
bin
56
57
|> Nx .from_binary (type)
57
- |> Nx .reshape (shape)
58
58
|> Nx .to_flat_list ()
59
59
|> Enum .with_index ()
60
60
|> Enum .group_by (& elem (&1 , 0 ), fn {_ , i} -> i end )
@@ -64,7 +64,6 @@ class_idx_to_train_idxs =
64
64
class_idx_to_test_idxs =
65
65
bin
66
66
|> Nx .from_binary (type)
67
- |> Nx .reshape (shape)
68
67
|> Nx .to_flat_list ()
69
68
|> Enum .with_index ()
70
69
|> Enum .group_by (& elem (&1 , 0 ), fn {_ , i} -> i end )
@@ -80,22 +79,15 @@ With the index in place, the training loop draws one anchor and one sibling set
80
79
``` elixir
81
80
defmodule GetImages do
82
81
def batch (train_images, class_idx_to_train_idxs) do
83
- anchors_idx = Enum .map (0 .. 9 , fn class ->
84
- indices = class_idx_to_train_idxs[class]
85
- Enum .random (indices)
86
- end )
87
-
88
- positives_idx = Enum .map (0 .. 9 , fn class ->
89
- indices = class_idx_to_train_idxs[class]
90
- # Exclude the anchor from possible positives
91
- anchor_idx = Enum .at (anchors_idx, class)
92
- indices
93
- |> Enum .filter (fn idx -> idx != anchor_idx end )
94
- |> Enum .random ()
95
- end )
82
+ {anchors_idx, positives_idx} =
83
+ Enum .unzip (for class <- 0 .. 9 do
84
+ [a, p] = Enum .take_random (class_idx_to_train_idxs[class], 2 )
85
+ {a, p}
86
+ end )
96
87
97
88
anchors = Nx .take (train_images, Nx .tensor (anchors_idx)) |> Nx .reshape ({10 , 32 , 32 , 3 })
98
89
positives = Nx .take (train_images, Nx .tensor (positives_idx)) |> Nx .reshape ({10 , 32 , 32 , 3 })
90
+
99
91
{anchors, positives}
100
92
end
101
93
end
@@ -155,13 +147,9 @@ defmodule MetricModel do
155
147
end
156
148
157
149
defn normalize (x) do
158
- den =
159
- Nx .multiply (x, x)
160
- |> Nx .sum (axes: [- 1 ], keep_axes: true )
161
- |> Nx .sqrt ()
162
- den = Nx .max (den, 1.0e-7 )
163
- Nx .divide (x, den)
164
- end
150
+ norm = Nx .LinAlg .norm (x, axes: [- 1 ], keep_axes: true )
151
+ Nx .divide (x, norm)
152
+ end
165
153
166
154
end
167
155
```
@@ -211,7 +199,7 @@ The training loop then uses that loss to nudge parameters, pulling same-class ve
211
199
defmodule MetricLearning do
212
200
import Nx .Defn
213
201
require Logger
214
-
202
+
215
203
defn objective_fn (predict_fn, params, {anchor, positive}) do
216
204
%{prediction: anchor_embeddings} = predict_fn .(params, %{" input" => anchor})
217
205
%{prediction: positive_embeddings} = predict_fn .(params, %{" input" => positive})
@@ -304,43 +292,40 @@ near_neighbors_per_example = 10
304
292
305
293
embeddings = Nx .rename (embeddings, [nil , nil ])
306
294
gram_matrix = Nx .dot (embeddings, Nx .transpose (embeddings))
295
+
307
296
{_vals , neighbors} = Nx .top_k (gram_matrix, k: near_neighbors_per_example + 1 )
297
+
308
298
:ok
309
299
```
310
300
311
- To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we randomly pick one example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors so you can see which images the network considers its nearest matches.
301
+ To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we pick its first example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors to see which images the network considers its nearest matches.
312
302
313
303
``` elixir
314
304
# take first image of each class
315
305
example_per_class_idx =
316
306
0 .. 9
317
307
|> Enum .map (fn class_idx ->
318
- class_idx_to_test_idxs[class_idx] |> Enum .random ( )
308
+ class_idx_to_test_idxs[class_idx] |> Enum .at ( 0 )
319
309
end )
320
310
|> Nx .tensor (type: {:s , 64 })
321
311
322
312
# take nearest neighbors for each example
323
313
neighbors_for_samples = Nx .take (neighbors, example_per_class_idx, axis: 0 )
324
314
325
- # show the ten closest images
326
- images = for row_idx <- 0 .. 9 do
327
- neighbour_idxs =
328
- neighbors_for_samples
329
- |> Nx .slice ([row_idx, 0 ], [1 , near_neighbors_per_example])
330
- |> Nx .squeeze ()
315
+ neighbour_idxs =
316
+ neighbors_for_samples
331
317
|> Nx .to_flat_list ()
332
318
333
- images =
334
- for idx <- neighbour_idxs do
335
- test_images
336
- |> Nx .take (Nx .tensor ([idx]), axis: 0 )
337
- |> Nx .squeeze ()
338
- |> Nx .transpose (axes: [:width , :height , :channels ])
339
- |> create_kino_image .()
340
- end
319
+ images =
320
+ for idx <- neighbour_idxs do
321
+ test_images[idx]
322
+ |> Nx .squeeze ()
323
+ |> Nx .transpose (axes: [:width , :height , :channels ])
324
+ |> create_kino_image .()
325
+ end
326
+
327
+ Kino .render (Kino .Layout .grid (images, columns: 11 ))
341
328
342
- Kino .render (Kino .Layout .grid (images, columns: near_neighbors_per_example))
343
- end
344
329
:ok
345
330
```
346
331
0 commit comments