From 1ceae68a8b3111d631734eb4718c1bd809c0daa6 Mon Sep 17 00:00:00 2001
From: Mike Cargian
Date: Mon, 9 Jun 2025 11:18:06 -0400
Subject: [PATCH 1/3] Add metric learning for image similarity search example
 livebook

---
 notebooks/vision/metric-learning.livemd | 404 ++++++++++++++++++++++++
 1 file changed, 404 insertions(+)
 create mode 100644 notebooks/vision/metric-learning.livemd

diff --git a/notebooks/vision/metric-learning.livemd b/notebooks/vision/metric-learning.livemd
new file mode 100644
index 000000000..7f82de067
--- /dev/null
+++ b/notebooks/vision/metric-learning.livemd
@@ -0,0 +1,404 @@
# Metric Learning

```elixir
Mix.install([
  {:nx, "~> 0.9.2"},
  {:axon, "~> 0.7.0"},
  {:scidata, "~> 0.1.9"},
  {:exla, "~> 0.9.2"},
  {:stb_image, "~> 0.5.2"},
  {:kino_vega_lite, "~> 0.1.13"}
])

Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)
```

## Dataset

We will be using the CIFAR-10 dataset. It consists of 50,000 training images and 10,000 test images: color images, each 32x32 pixels, divided evenly into 10 categories such as airplanes, cars, birds, and cats.

```elixir
{train_images, train_labels} = Scidata.CIFAR10.download()
{test_images, test_labels} = Scidata.CIFAR10.download_test()
```

The images are stored in binary (`bin`) format with the shape `{50000, 3, 32, 32}`, representing 50,000 images with 3 color channels (RGB), each 32x32 pixels in size. To prepare the data for training, we normalize the pixel values by dividing by 255, scaling them from the original 0-255 range to a 0-1 range. This normalization helps the model train faster and more efficiently.

```elixir
normalize_images = fn images ->
  {bin, type, shape} = images

  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape, names: [:count, :channels, :width, :height])
  # Move channels to the last position to match what the conv layers expect
  |> Nx.transpose(axes: [:count, :width, :height, :channels])
  |> Nx.divide(255.0)
end

train_images = normalize_images.(train_images)
test_images = normalize_images.(test_images)

:ok
```

### Create anchor/positive pairs

In metric learning, we don't hand the model lone examples; instead, we show it sibling snapshots. Each training step picks an anchor (e.g., a random airplane photo) and a positive (another airplane shot). By treating those two as twins, the network learns to pull same-class images closer together in its feature space. To grab those pairs on the fly, we first build a simple index that groups each class label (0-9) with its image indices.

```elixir
{bin, type, shape} = train_labels

class_idx_to_train_idxs =
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.to_flat_list()
  |> Enum.with_index()
  |> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)

{bin, type, shape} = test_labels

class_idx_to_test_idxs =
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.to_flat_list()
  |> Enum.with_index()
  |> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)

:ok
```

### Select images

With the index in place, the training loop draws one anchor and one positive from each of the 10 classes. A helper module picks them at random for each class, 0 through 9, so the model samples fresh pairs from every category each batch.
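
Before writing that helper, we can sanity-check the index with a small illustrative cell (purely optional, and not part of the training pipeline). CIFAR-10 is balanced, so each class should map to roughly 5,000 training indices:

```elixir
# Count the training indices collected for each of the 10 classes
for class <- 0..9, do: {class, length(class_idx_to_train_idxs[class])}
```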

```elixir
defmodule GetImages do
  def batch(train_images, class_idx_to_train_idxs) do
    anchors_idx = Enum.map(0..9, fn class ->
      indices = class_idx_to_train_idxs[class]
      Enum.random(indices)
    end)

    positives_idx = Enum.map(0..9, fn class ->
      indices = class_idx_to_train_idxs[class]
      # Exclude the anchor from possible positives
      anchor_idx = Enum.at(anchors_idx, class)
      indices
      |> Enum.filter(fn idx -> idx != anchor_idx end)
      |> Enum.random()
    end)

    anchors = Nx.take(train_images, Nx.tensor(anchors_idx)) |> Nx.reshape({10, 32, 32, 3})
    positives = Nx.take(train_images, Nx.tensor(positives_idx)) |> Nx.reshape({10, 32, 32, 3})
    {anchors, positives}
  end
end

:ok
```

### Example images

To peek into our dataset, we'll display one anchor-positive pair from each class, using the `create_kino_image` helper below to turn raw tensors into 64x64 PNG images.

```elixir
create_kino_image = fn image ->
  image
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)
  |> StbImage.from_nx()
  |> StbImage.resize(64, 64)
  |> StbImage.to_binary(:png)
  |> Kino.Image.new(:png)
end

{anchors, positives} = GetImages.batch(train_images, class_idx_to_train_idxs)

images =
  [anchors, positives]
  |> Enum.flat_map(fn tensor ->
    Enum.map(0..9, fn i ->
      tensor
      |> Nx.take(Nx.tensor([i]), axis: 0)
      |> Nx.squeeze()
      |> create_kino_image.()
    end)
  end)

Kino.Layout.grid(images, columns: 10)
```

### Embedding model

Our embedding network applies three 2D convolutional blocks (with ReLU activations and downsampling) before collapsing the spatial dimensions via global average pooling. A final dense layer then projects into an 8-dimensional embedding space, and we ℓ₂-normalize each vector.

In simpler terms, our detector scans each image with a small window three times; each pass spots edges and textures while shedding unneeded detail. It then averages those responses and feeds them into a final layer that spits out an eight-number vector: a compact "fingerprint". We scale each vector to unit length so they all live on the same unit sphere, then train by showing "same" or "different" pairs, pulling matching vectors together and pushing the rest apart.

```elixir
defmodule MetricModel do
  import Nx.Defn

  def build_model do
    Axon.input("input", shape: {nil, 32, 32, 3})
    |> Axon.conv(32, kernel_size: 3, strides: 2, activation: :relu, name: "conv32")
    |> Axon.conv(64, kernel_size: 3, strides: 2, activation: :relu, name: "conv64")
    |> Axon.conv(128, kernel_size: 3, strides: 2, activation: :relu, name: "conv128")
    |> Axon.global_avg_pool()
    |> Axon.dense(8)
    |> Axon.nx(&normalize/1)
  end

  defn normalize(x) do
    den =
      Nx.multiply(x, x)
      |> Nx.sum(axes: [-1], keep_axes: true)
      |> Nx.sqrt()
    den = Nx.max(den, 1.0e-7)
    Nx.divide(x, den)
  end

end
```

To monitor training progress, we use a `KinoAxon` module that plots the loss at the end of each epoch. It hooks into the training loop by handling the `:epoch_completed` event, extracts the current loss, and streams it to a live VegaLite chart.

```elixir
alias VegaLite, as: Vl

defmodule KinoAxon do
  def plot_losses(loop) do
    vl_widget =
      Vl.new(width: 600, height: 400)
      |> Vl.mark(:point, tooltip: true)
      |> Vl.encode_field(:x, "epoch", type: :ordinal, title: "Epoch")
      |> Vl.encode_field(:y, "loss",
        type: :quantitative,
        scale: [zero: false, nice: true],
        title: "Loss"
      )
      |> Vl.encode_field(:color, "dataset", type: :nominal)
      |> Kino.VegaLite.new()
      |> Kino.render()

    handler = fn state ->
      %Axon.Loop.State{epoch: epoch, iteration: _iter, step_state: step_state} = state

      Kino.VegaLite.push_many(vl_widget, [
        %{epoch: epoch, loss: Nx.to_number(step_state[:epoch_avg_loss]), dataset: "train"}
      ])

      {:continue, state}
    end

    Axon.Loop.handle_event(loop, :epoch_completed, handler)
  end
end
```

### Training

We wrap our model in a custom `batch_step` that does three things each batch:

* Embed both anchors and positives through the network.
* Score them by taking dot products (our raw "how similar?" numbers).
* Compute a softmax cross-entropy loss over those scores, using temperature scaling to sharpen or soften the comparison.

The training loop then uses that loss to nudge the parameters, pulling same-class vectors together and pushing others apart, one batch at a time.

```elixir
defmodule MetricLearning do
  import Nx.Defn
  require Logger

  defn objective_fn(predict_fn, params, {anchor, positive}) do
    %{prediction: anchor_embeddings} = predict_fn.(params, %{"input" => anchor})
    %{prediction: positive_embeddings} = predict_fn.(params, %{"input" => positive})

    similarities = Nx.dot(anchor_embeddings, [1], positive_embeddings, [1])
    temperature = 0.2
    similarities = similarities / temperature
    sparse_labels = Nx.iota({10})

    Axon.Losses.categorical_cross_entropy(sparse_labels, similarities,
      reduction: :mean,
      sparse: true,
      from_logits: true
    )
  end

  defn batch_step(predict_fn, optim, {anchor, positive}, state) do
    # Compute the gradient of the objective defined above
    {loss, gradients} =
      value_and_grad(state.model_state, &objective_fn(predict_fn, &1, {anchor, positive}))

    {updates, new_optimizer_state} = optim.(gradients, state.optimizer_state, state.model_state)
    new_params = Polaris.Updates.apply_updates(state.model_state, updates)

    %{
      state
      | model_state: new_params,
        optimizer_state: new_optimizer_state,
        epoch_loss: state[:epoch_loss] + loss,
        epoch_count: state[:epoch_count] + 1,
        epoch_avg_loss: (state[:epoch_loss] + loss) / (state[:epoch_count] + 1)
    }
  end

  def init(template, init_fn, init_optim) do
    model_state = init_fn.(template, Axon.ModelState.empty())

    %{
      model_state: model_state,
      optimizer_state: init_optim.(model_state),
      epoch_loss: Nx.tensor(0.0),
      epoch_count: Nx.tensor(0),
      epoch_avg_loss: Nx.tensor(0.0)
    }
  end

  def run(train_images, class_idx_to_train_idxs) do
    {optim_init_fn, optim_update_fn} = Polaris.Optimizers.adam()
    {init_fn, predict_fn} = Axon.build(MetricModel.build_model(), mode: :train, debug: false)

    step = &batch_step(predict_fn, optim_update_fn, &1, &2)
    init = fn {template, _}, _state -> init(%{"input" => template}, init_fn, optim_init_fn) end

    training_data =
      Stream.repeatedly(fn ->
        GetImages.batch(train_images, class_idx_to_train_idxs)
      end)

    final_state =
      Axon.Loop.loop(step, init)
      |> Axon.Loop.log(
        fn %Axon.Loop.State{epoch: epoch, step_state: state} ->
          loss_str = :io_lib.format(~c"~.4f", [Nx.to_number(state[:epoch_avg_loss])])
          "\rEpoch: #{epoch}, Loss: #{loss_str}\n"
        end,
        event: :epoch_completed
      )
      |> KinoAxon.plot_losses()
      |> Axon.Loop.run(training_data, %{}, iterations: 1000, epochs: 20, compiler: EXLA)

    {final_state, predict_fn}
  end
end

{final_state, predict_fn} = MetricLearning.run(train_images, class_idx_to_train_idxs)

:ok
```

### Testing

After training, we pass every test image through our network to get its embedding. We then build a "who's most like whom" table by taking the dot product of each embedding with all the others, and finally call `Nx.top_k` with `k = 11` to pull out each image's ten closest cousins (the top hit is the image itself).

```elixir
final_params = final_state.step_state.model_state
near_neighbors_per_example = 10

%{prediction: embeddings} = predict_fn.(final_params, %{"input" => test_images})

embeddings = Nx.rename(embeddings, [nil, nil])
gram_matrix = Nx.dot(embeddings, Nx.transpose(embeddings))
{_vals, neighbors} = Nx.top_k(gram_matrix, k: near_neighbors_per_example + 1)
:ok
```

To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we randomly pick one example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors so you can see which images the network considers its nearest matches.

```elixir
# take one image of each class
example_per_class_idx =
  0..9
  |> Enum.map(fn class_idx ->
    class_idx_to_test_idxs[class_idx] |> Enum.random()
  end)
  |> Nx.tensor(type: {:s, 64})

# take nearest neighbors for each example
neighbors_for_samples = Nx.take(neighbors, example_per_class_idx, axis: 0)

# show the ten closest images
images = for row_idx <- 0..9 do
  neighbour_idxs =
    neighbors_for_samples
    |> Nx.slice([row_idx, 0], [1, near_neighbors_per_example])
    |> Nx.squeeze()
    |> Nx.to_flat_list()

  images =
    for idx <- neighbour_idxs do
      test_images
      |> Nx.take(Nx.tensor([idx]), axis: 0)
      |> Nx.squeeze()
      |> Nx.transpose(axes: [:width, :height, :channels])
      |> create_kino_image.()
    end

  Kino.render(Kino.Layout.grid(images, columns: near_neighbors_per_example))
end
:ok
```

### Confusion Matrix

To measure performance numerically, we treat each example's nearest neighbors as a simple classifier and summarize the results in a confusion matrix. We pick ten images from each of the ten classes, find the ten closest embeddings for each one, and ask: "Do these neighbors share the same label as our query image?" Each group of predicted classes is then compared against the true class to populate the matrix.
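
Here is the voting idea in miniature, as a hypothetical toy example separate from the real pipeline below: if a class-3 query image came back with the ten neighbor labels shown, its row of the matrix would collect 8 counts under predicted class 3 and 2 counts under predicted class 5.

```elixir
# Hypothetical neighbor labels for a single class-3 query image
neighbor_labels = [3, 3, 5, 3, 3, 3, 5, 3, 3, 3]

# Tally how the ten votes spread across predicted classes
Enum.frequencies(neighbor_labels)
#=> %{3 => 8, 5 => 2}
```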

```elixir
{bin, type, shape} = test_labels

test_labels_tensor =
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)

# take 10 images from each class in the test set
test_idxs =
  0..9
  |> Enum.flat_map(fn class_idx ->
    class_idx_to_test_idxs[class_idx]
    |> Enum.take(10)
  end)

# each class contributes 100 predictions: 10 neighbors for each of its 10 examples
actual_classes =
  0..9
  |> Enum.flat_map(fn class_idx ->
    List.duplicate(class_idx, 100)
  end)

predicted_classes =
  neighbors
  |> Nx.take(Nx.tensor(test_idxs), axis: 0)
  # 100 examples x 10 neighbors; drop column 0, which is the example itself
  |> Nx.slice([0, 1], [100, 10])
  |> Nx.to_flat_list()
  |> Enum.map(fn idx -> Nx.to_number(test_labels_tensor[idx]) end)

Vl.new(title: "Confusion Matrix", width: 700, height: 700)
|> Vl.data_from_values(%{
  predicted: predicted_classes,
  actual: actual_classes
})
|> Vl.layers([
  # First layer: draw the rects with color encoding
  Vl.new()
  |> Vl.mark(:rect, tooltip: false)
  |> Vl.encode_field(:x, "predicted", title: "Predicted Label")
  |> Vl.encode_field(:y, "actual", title: "True Label")
  |> Vl.encode(:color, aggregate: :count, legend: [title: "Matches"]),

  # Second layer: add the count as centered text
  Vl.new()
  |> Vl.mark(:text,
    align: "center",
    baseline: "middle",
    font_size: 16
  )
  |> Vl.encode_field(:x, "predicted")
  |> Vl.encode_field(:y, "actual")
  |> Vl.encode(:text, aggregate: :count)
])
```

From 51e18936b18e99ecb412f2d930469c442d221ffe Mon Sep 17 00:00:00 2001
From: Mike Cargian
Date: Fri, 13 Jun 2025 16:08:23 -0400
Subject: [PATCH 2/3] metric learning PR updates

---
 notebooks/vision/metric-learning.livemd | 69 ++++++++++---------------
 1 file changed, 27 insertions(+), 42 deletions(-)

diff --git a/notebooks/vision/metric-learning.livemd b/notebooks/vision/metric-learning.livemd
index 7f82de067..babe12d24 100644
--- a/notebooks/vision/metric-learning.livemd
+++ b/notebooks/vision/metric-learning.livemd
@@ -12,6 +12,7 @@ Mix.install([

Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)
+
```

## Dataset

@@ -54,7 +55,6 @@ In metric learning, we don't hand the model lone examples; instead, we show it
class_idx_to_train_idxs =
  bin
  |> Nx.from_binary(type)
-  |> Nx.reshape(shape)
  |> Nx.to_flat_list()
  |> Enum.with_index()
  |> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)

@@ -64,7 +64,6 @@
class_idx_to_test_idxs =
  bin
  |> Nx.from_binary(type)
-  |> Nx.reshape(shape)
  |> Nx.to_flat_list()
  |> Enum.with_index()
  |> Enum.group_by(&elem(&1, 0), fn {_, i} -> i end)

@@ -80,22 +79,15 @@ With the index in place, the training loop draws one anchor and one positive
```elixir
defmodule GetImages do
  def batch(train_images, class_idx_to_train_idxs) do
-    anchors_idx = Enum.map(0..9, fn class ->
-      indices = class_idx_to_train_idxs[class]
-      Enum.random(indices)
-    end)
+    {anchors_idx, positives_idx} =
+      Enum.unzip(for class <- 0..9 do
+        [a, p] = Enum.take_random(class_idx_to_train_idxs[class], 2)
+        {a, p}
+      end)

-    positives_idx = Enum.map(0..9, fn class ->
-      indices = class_idx_to_train_idxs[class]
-      # Exclude the anchor from possible positives
-      anchor_idx = Enum.at(anchors_idx, class)
-      indices
-      |> Enum.filter(fn idx -> idx != anchor_idx end)
-      |> Enum.random()
-    end)
+    anchors = Nx.take(train_images, Nx.tensor(anchors_idx)) |> Nx.rename(nil)
+    positives = Nx.take(train_images, Nx.tensor(positives_idx)) |> Nx.rename(nil)

-    anchors = Nx.take(train_images, Nx.tensor(anchors_idx)) |> Nx.reshape({10, 32, 32, 3})
-    positives = Nx.take(train_images, Nx.tensor(positives_idx)) |> Nx.reshape({10, 32, 32, 3})
    {anchors, positives}
  end
end

@@ -155,13 +147,9 @@ defmodule MetricModel do
  end

  defn normalize(x) do
-    den =
-      Nx.multiply(x, x)
-      |> Nx.sum(axes: [-1], keep_axes: true)
-      |> Nx.sqrt()
-    den = Nx.max(den, 1.0e-7)
-    Nx.divide(x, den)
-  end
+    norm = Nx.LinAlg.norm(x, axes: [-1], keep_axes: true)
+    Nx.divide(x, norm)
+  end

end

@@ -211,7 +199,7 @@ The training loop then uses that loss to nudge the parameters
defmodule MetricLearning do
  import Nx.Defn
  require Logger
-
+
  defn objective_fn(predict_fn, params, {anchor, positive}) do

@@ -304,43 +292,40 @@ near_neighbors_per_example = 10
%{prediction: embeddings} = predict_fn.(final_params, %{"input" => test_images})

embeddings = Nx.rename(embeddings, [nil, nil])
gram_matrix = Nx.dot(embeddings, Nx.transpose(embeddings))
+
{_vals, neighbors} = Nx.top_k(gram_matrix, k: near_neighbors_per_example + 1)
+
:ok
```

-To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we randomly pick one example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors so you can see which images the network considers its nearest matches.
+To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we pick the first example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors to see which images the network considers its nearest matches.

```elixir
# take one image of each class
example_per_class_idx =
  0..9
  |> Enum.map(fn class_idx ->
-    class_idx_to_test_idxs[class_idx] |> Enum.random()
+    class_idx_to_test_idxs[class_idx] |> Enum.at(0)
  end)
  |> Nx.tensor(type: {:s, 64})

# take nearest neighbors for each example
neighbors_for_samples = Nx.take(neighbors, example_per_class_idx, axis: 0)

-# show the ten closest images
-images = for row_idx <- 0..9 do
-  neighbour_idxs =
-    neighbors_for_samples
-    |> Nx.slice([row_idx, 0], [1, near_neighbors_per_example])
-    |> Nx.squeeze()
+neighbour_idxs =
+  neighbors_for_samples
  |> Nx.to_flat_list()

-  images =
-    for idx <- neighbour_idxs do
-      test_images
-      |> Nx.take(Nx.tensor([idx]), axis: 0)
-      |> Nx.squeeze()
-      |> Nx.transpose(axes: [:width, :height, :channels])
-      |> create_kino_image.()
-    end
+images =
+  for idx <- neighbour_idxs do
+    test_images[idx]
+    |> Nx.squeeze()
+    |> Nx.transpose(axes: [:width, :height, :channels])
+    |> create_kino_image.()
+  end
+
+Kino.render(Kino.Layout.grid(images, columns: 11))

-  Kino.render(Kino.Layout.grid(images, columns: near_neighbors_per_example))
-end
:ok
```

From 5690b7f0a152f4a99f59aee6eed601d4cbd928fe Mon Sep 17 00:00:00 2001
From: Mike Cargian
Date: Fri, 20 Jun 2025 18:12:49 -0400
Subject: [PATCH 3/3] avoid the transpose, instead contract the 8-sized axis
 in both tensors

---
 notebooks/vision/metric-learning.livemd | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/notebooks/vision/metric-learning.livemd b/notebooks/vision/metric-learning.livemd
index babe12d24..94c48dfd4 100644
--- a/notebooks/vision/metric-learning.livemd
+++ b/notebooks/vision/metric-learning.livemd
@@ -291,7 +291,7 @@ near_neighbors_per_example = 10
%{prediction: embeddings} = predict_fn.(final_params, %{"input" => test_images})

embeddings = Nx.rename(embeddings, [nil, nil])
-gram_matrix = Nx.dot(embeddings, Nx.transpose(embeddings))
+gram_matrix = Nx.dot(embeddings, [1], embeddings, [1])

{_vals, neighbors} = Nx.top_k(gram_matrix, k: near_neighbors_per_example + 1)