Skip to content

Commit a73f3e8

Browse files
committed
feat: hilbert curves
1 parent e07fe0b commit a73f3e8

File tree

11 files changed

+2054
-120
lines changed

11 files changed

+2054
-120
lines changed

dataset_preprocessing.ipynb

Lines changed: 377 additions & 0 deletions
Large diffs are not rendered by default.

demo_hilbert_curve.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Demo script for visualizing Hilbert curve patching in Vision Transformers.
4+
5+
This script demonstrates:
6+
1. How a Hilbert curve maps through an image grid
7+
2. How patching/unpatching with Hilbert ordering works
8+
3. Visual comparison between row-major and Hilbert curve patch ordering
9+
10+
Usage:
11+
python demo_hilbert_curve.py [--image IMAGE_PATH] [--patch_size PATCH_SIZE]
12+
13+
Options:
14+
--image: Path to an image file (default: will use a sample image)
15+
--patch_size: Size of patches (default: 16)
16+
"""
17+
import os
18+
os.environ["JAX_PLATFORMS"] = "cpu"
19+
import argparse
20+
import numpy as np
21+
import jax
22+
import jax.numpy as jnp
23+
import matplotlib.pyplot as plt
24+
from PIL import Image
25+
import requests
26+
from io import BytesIO
27+
import os
28+
import cv2
29+
from flaxdiff.models.hilbert import (
30+
visualize_hilbert_curve,
31+
demo_hilbert_patching,
32+
hilbert_patchify,
33+
hilbert_unpatchify,
34+
hilbert_indices,
35+
inverse_permutation,
36+
patchify
37+
)
38+
39+
def load_sample_image():
    """Download the demo's sample photo and return it as a float RGB array.

    Used when no usable ``--image`` path is supplied.

    Returns:
        NumPy array built from the downloaded picture after conversion to
        RGB, with pixel values divided by 255 (i.e. scaled to [0, 1]).

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    print("Downloading a sample image...")
    # Use a relatively small but detailed image
    url = 'https://www.caledoniaplay.com/wp-content/uploads/2016/01/EDU-PRODUCT-DESCRIPTION-gallery-image-OUTDOOR-SEATING-RUSTIC-LOG-BENCH-1-555x462.jpg'
    # Bound the wait so an unreachable host cannot hang the demo forever, and
    # fail loudly on HTTP errors instead of handing an error page to PIL.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as img:
        # Force three channels so grayscale/palette/RGBA downloads end up in
        # the same layout that load_image() produces.
        rgb = img.convert('RGB')
    return np.array(rgb) / 255.0  # Normalize to [0, 1]
47+
48+
def load_image(path):
    """Load an image file as a float RGB array.

    No resizing or cropping happens here; the image is returned at its
    native resolution.

    Args:
        path: Location of the image, passed straight to ``PIL.Image.open``.

    Returns:
        NumPy array of the RGB pixels divided by 255 (values in [0, 1]).
    """
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released once the pixels have been copied out.
    with Image.open(path) as img:
        # Convert to RGB if needed
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        print(f"Loaded image of size: {rgb.size}")
        pixels = np.array(rgb)
    return pixels / 255.0  # Normalize to [0, 1]
59+
60+
def main():
    """Run the Hilbert-curve patching demo end to end.

    Parses ``--image`` and ``--patch_size``, loads (or downloads) an image,
    resizes it to a fixed square, then:

    1. plots the Hilbert curve over the patch grid,
    2. runs the library's patching demo,
    3. pushes the Hilbert-ordered patches through a toy linear "transformer"
       and reconstructs the image.

    All four figures are written as PNG files to the current working
    directory and then shown with ``plt.show()``.
    """
    parser = argparse.ArgumentParser(description='Demonstrate Hilbert curve patching for ViTs')
    parser.add_argument('--image', type=str, default=None, help='Path to input image')
    parser.add_argument('--patch_size', type=int, default=16, help='Patch size')
    args = parser.parse_args()

    # Every input is resized to this square before patching.
    demo_size = 512

    # Reject patch sizes that would divide by zero or crop the image to
    # nothing, before any download or heavy work starts.
    patch_size = args.patch_size
    if not 0 < patch_size <= demo_size:
        parser.error(f"--patch_size must be between 1 and {demo_size}, got {patch_size}")

    # Load image
    if args.image and os.path.exists(args.image):
        print(f"Loading image from {args.image}...")
        image = load_image(args.image)
    else:
        if args.image:
            # Keep the best-effort fallback, but tell the user it happened.
            print(f"Image path {args.image} not found; falling back to the sample image.")
        image = load_sample_image()

    print(f"Original image shape: {image.shape}")
    image = cv2.resize(image, (demo_size, demo_size))  # Resize to a fixed size for demo
    print(f"Image shape: {image.shape}")
    # Ensure image dimensions are divisible by patch_size
    h, w = image.shape[:2]

    # Crop to make dimensions divisible by patch_size
    new_h = (h // patch_size) * patch_size
    new_w = (w // patch_size) * patch_size
    if new_h != h or new_w != w:
        print(f"Cropping image from {h}x{w} to {new_h}x{new_w} to make divisible by patch size {patch_size}")
        image = image[:new_h, :new_w]

    # 1. Visualize the Hilbert curve mapping
    print("\n1. Visualizing Hilbert curve mapping...")
    fig_map = visualize_hilbert_curve(new_h, new_w, patch_size)

    # 2. Demonstrate the patching process
    print("\n2. Demonstrating Hilbert curve patching...")
    fig_demo, fig_recon = demo_hilbert_patching(image, patch_size)

    # 3. Additional example: Process through a simulated transformer block
    print("\n3. Simulating how patches would flow through a transformer...")

    # Convert to JAX array and add batch dimension
    jax_img = jnp.array(image)[None, ...]  # [1, H, W, C]

    # Get Hilbert curve patches and inverse indices
    patches, inv_idx = hilbert_patchify(jax_img, patch_size)

    print(f"Original image shape: {jax_img.shape}")
    print(f"Patches shape: {patches.shape}")

    # Simulate a transformer block that operates on the patch sequence
    def simulate_transformer_block(patches):
        """
        Simulate a transformer block by applying a simple operation to patches.
        For demonstration purposes, we'll just multiply by a fixed
        near-identity weight matrix.
        """
        patch_dim = patches.shape[-1]

        # Simulate learned weights (identity + small random values); the
        # fixed PRNG key keeps the demo output reproducible.
        key = jax.random.PRNGKey(42)
        weights = jnp.eye(patch_dim) + jax.random.normal(key, (patch_dim, patch_dim)) * 0.05

        # Apply "attention" (just a matrix multiply for demo)
        return jnp.matmul(patches, weights)

    # Process patches as if through a transformer
    processed_patches = simulate_transformer_block(patches)

    # Unpatchify back to image space
    h, w, c = jax_img.shape[1:]
    reconstructed = hilbert_unpatchify(processed_patches, inv_idx, patch_size, h, w, c)

    # Visualize the processed result
    fig_processed, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].imshow(np.array(jax_img[0]))
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    # The toy weights can push values outside [0, 1]; clip for display.
    ax[1].imshow(np.clip(np.array(reconstructed[0]), 0, 1))
    ax[1].set_title("After Simulated Transformer Processing")
    ax[1].axis('off')
    plt.tight_layout()

    # Save all figures
    print("\nSaving visualization figures...")
    fig_map.savefig("hilbert_curve_mapping.png")
    fig_demo.savefig("hilbert_patch_demo.png")
    fig_recon.savefig("hilbert_patch_reconstruction.png")
    fig_processed.savefig("hilbert_transformer_simulation.png")

    print("\nDone! Check the following output files:")
    print("- hilbert_curve_mapping.png - Visualizes how Hilbert curve maps through a grid")
    print("- hilbert_patch_demo.png - Shows patch ordering comparison")
    print("- hilbert_patch_reconstruction.png - Shows original vs reconstructed image")
    print("- hilbert_transformer_simulation.png - Shows a simple simulated transformer effect")

    # Display plots if running in interactive environment
    plt.show()


if __name__ == "__main__":
    main()

flaxdiff/models/common.py

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -335,73 +335,4 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
335335

336336
out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
337337

338-
return out
339-
340-
# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
def _d2xy(n, d):
    """Map a distance ``d`` along the Hilbert curve to grid coordinates.

    Args:
        n: Side length of the square grid; expected to be a power of two.
        d: Position along the curve, ``0 <= d < n * n``.

    Returns:
        Tuple ``(x, y)`` with both coordinates in ``[0, n)``.
    """
    x = y = 0
    t = d
    s = 1  # side length of the sub-square assembled so far
    while s < n:
        # Two bits of t select the quadrant at this level.
        rx = (t // 2) & 1
        ry = (t ^ rx) & 1
        if ry == 0:
            if rx == 1:
                # Reflect inside the current s x s sub-square. Using the
                # full grid size n here pushes points outside the grid.
                x = s - 1 - x
                y = s - 1 - y
            # Transpose so the sub-curve enters/exits on the correct side.
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y
359-
360-
# Hilbert index mapping for a rectangular grid of patches H_P x W_P
def hilbert_indices(H_P, W_P):
    """Return row-major patch indices listed in Hilbert-curve visiting order.

    The curve is generated on the smallest power-of-two square that covers
    the grid; cells that fall outside the ``H_P x W_P`` rectangle are
    skipped.

    Args:
        H_P: Number of patch rows.
        W_P: Number of patch columns.

    Returns:
        1-D ``jnp.int32`` array whose entries are ``row * W_P + col`` for
        each in-grid cell, in the order the curve reaches them. Empty when
        either dimension is not positive.
    """
    if H_P <= 0 or W_P <= 0:
        # Nothing to order; also avoids taking log2 of a non-positive size.
        return jnp.zeros((0,), dtype=jnp.int32)

    total = H_P * W_P
    # ceil(log2(size)) computed exactly with integers.
    order = (int(max(H_P, W_P)) - 1).bit_length()
    n = 1 << order

    indices = []
    for d in range(n * n):
        x, y = _d2xy(n, d)
        # x is column index, y is row index
        if x < W_P and y < H_P:
            indices.append(y * W_P + x)  # row-major linear index
            if len(indices) == total:
                break  # every grid cell has been visited
    return jnp.array(indices, dtype=jnp.int32)
377-
378-
# Inverse permutation: build inv such that inv[idx[i]] = i
def inverse_permutation(idx):
    """Invert an index array by scattering positions into their targets.

    Args:
        idx: 1-D integer array; intended to be a permutation of
            ``0 .. len(idx) - 1``.

    Returns:
        Array with the same shape and dtype as ``idx`` where slot
        ``idx[i]`` holds ``i``.
    """
    positions = jnp.arange(idx.shape[0], dtype=idx.dtype)
    return jnp.zeros_like(idx).at[idx].set(positions)
384-
385-
# Patchify using Hilbert ordering: extract patches and reorder sequence
def hilbert_patchify(x, patch_size):
    """Split a batch of images into flattened patches in Hilbert order.

    Args:
        x: Array unpacked as ``(B, H, W, C)``.
        patch_size: Side length of each square patch.

    Returns:
        Array of flattened patches, one row per patch, with the patch axis
        reordered by ``hilbert_indices`` for the patch grid.
    """
    _, height, width, _ = x.shape
    grid_rows = height // patch_size
    grid_cols = width // patch_size
    # Extract patches in row-major
    row_major = rearrange(
        x,
        'b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
        p1=patch_size,
        p2=patch_size,
    )
    curve_order = hilbert_indices(grid_rows, grid_cols)
    return row_major[:, curve_order, :]
395-
396-
# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
def hilbert_unpatchify(patches, patch_size, H, W, C):
    """Rebuild images from Hilbert-ordered flattened patches.

    Args:
        patches: Array unpacked as ``(B, N, D)``; ``N`` must equal the
            number of patches in the ``H x W`` image.
        patch_size: Side length of each square patch.
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of channels.

    Returns:
        Array laid out as ``b (h p1) (w p2) c``, i.e. the images
        reassembled from the patches.

    Raises:
        ValueError: If the sequence length does not match the patch grid.
    """
    _, n_patches, _ = patches.shape
    H_P = H // patch_size
    W_P = W // patch_size
    if n_patches != H_P * W_P:
        # Fail early with a clear message instead of gathering with
        # mismatched indices and failing later in the reshape.
        raise ValueError(
            f"Expected {H_P * W_P} patches for a {H}x{W} image with patch size "
            f"{patch_size}, got {n_patches}"
        )
    inv = inverse_permutation(hilbert_indices(H_P, W_P))
    # Reorder back to row-major
    linear = patches[:, inv, :]
    # Reconstruct image
    return rearrange(
        linear,
        'b (h w) (p1 p2 c) -> b (h p1) (w p2) c',
        h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C,
    )
338+
return out

0 commit comments

Comments
 (0)