1+ #!/usr/bin/env python3
2+ """
3+ Demo script for visualizing Hilbert curve patching in Vision Transformers.
4+
5+ This script demonstrates:
6+ 1. How a Hilbert curve maps through an image grid
7+ 2. How patching/unpatching with Hilbert ordering works
8+ 3. Visual comparison between row-major and Hilbert curve patch ordering
9+
Usage:
    python demo_hilbert_curve.py [--image IMAGE_PATH] [--patch_size PATCH_SIZE]

Options:
    --image: Path to an image file (default, or if the path does not exist:
             a sample image is downloaded from the internet)
    --patch_size: Size of patches (default: 16)
16+ """
17+ import os
18+ os .environ ["JAX_PLATFORMS" ] = "cpu"
19+ import argparse
20+ import numpy as np
21+ import jax
22+ import jax .numpy as jnp
23+ import matplotlib .pyplot as plt
24+ from PIL import Image
25+ import requests
26+ from io import BytesIO
27+ import os
28+ import cv2
29+ from flaxdiff .models .hilbert import (
30+ visualize_hilbert_curve ,
31+ demo_hilbert_patching ,
32+ hilbert_patchify ,
33+ hilbert_unpatchify ,
34+ hilbert_indices ,
35+ inverse_permutation ,
36+ patchify
37+ )
38+
def load_sample_image():
    """Download the sample image and return it as a float RGB array.

    Used as the fallback when no usable image path is supplied.

    Returns:
        NumPy array of the decoded image, converted to RGB, with pixel
        values divided by 255 so they lie in [0, 1].

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server responds with an HTTP error status.
    """
    print("Downloading a sample image...")
    # Use a relatively small but detailed image
    url = 'https://www.caledoniaplay.com/wp-content/uploads/2016/01/EDU-PRODUCT-DESCRIPTION-gallery-image-OUTDOOR-SEATING-RUSTIC-LOG-BENCH-1-555x462.jpg'
    # A timeout keeps the demo from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Always return 3-channel RGB so callers can rely on an (H, W, C) layout
    # regardless of how the remote file happens to be encoded.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return np.array(img) / 255.0  # Normalize to [0, 1]
47+
def load_image(path):
    """Load an image file and return it as a float RGB array.

    No resizing or cropping is performed here; the returned array has the
    same spatial size as the file on disk.

    Args:
        path: Location of the image, passed straight to ``PIL.Image.open``.

    Returns:
        NumPy array of the RGB image with pixel values divided by 255 so
        they lie in [0, 1].
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been copied into the NumPy array.
    with Image.open(path) as source:
        # Convert to RGB if needed
        rgb = source if source.mode == 'RGB' else source.convert('RGB')
        print(f"Loaded image of size: {rgb.size} ")
        # np.array() forces the lazy decode while the file is still open.
        return np.array(rgb) / 255.0  # Normalize to [0, 1]
59+
def main():
    """Run the Hilbert-curve patching demo from the command line.

    Steps:
        1. Parse ``--image`` and ``--patch_size``.
        2. Load the image from disk, or download the sample image when no
           existing path is given.
        3. Resize to 512x512 and crop so both sides are multiples of the
           patch size.
        4. Build the curve-mapping figure, the patching/reconstruction
           figures, and a figure showing the image after a toy linear
           "transformer" step applied to the Hilbert-ordered patches.
        5. Save the four figures as PNG files in the current working
           directory and call ``plt.show()``.

    Invalid ``--patch_size`` values are reported through ``parser.error``,
    which prints a usage message and exits.
    """
    parser = argparse.ArgumentParser(description='Demonstrate Hilbert curve patching for ViTs')
    parser.add_argument('--image', type=str, default=None, help='Path to input image')
    parser.add_argument('--patch_size', type=int, default=16, help='Patch size')
    args = parser.parse_args()

    patch_size = args.patch_size
    # Reject non-positive sizes before doing any I/O: zero would raise
    # ZeroDivisionError in the crop arithmetic below and negative sizes are
    # meaningless for patching.
    if patch_size <= 0:
        parser.error(f"--patch_size must be a positive integer, got {patch_size}")

    # Load image
    if args.image and os.path.exists(args.image):
        print(f"Loading image from {args.image} ...")
        image = load_image(args.image)
    else:
        if args.image:
            # Keep the best-effort fallback, but do not hide the fact that
            # the requested file was not used.
            print(f"Image path {args.image} not found; falling back to the sample image.")
        image = load_sample_image()

    print(f"Original image shape: {image.shape} ")
    image = cv2.resize(image, (512, 512))  # Resize to a fixed size for demo
    print(f"Image shape: {image.shape} ")
    # Ensure image dimensions are divisible by patch_size
    h, w = image.shape[:2]
    # A patch larger than the image would crop it down to zero pixels.
    if patch_size > min(h, w):
        parser.error(
            f"--patch_size ({patch_size}) must not exceed the demo image size ({h}x{w})"
        )

    # Crop to make dimensions divisible by patch_size
    new_h = (h // patch_size) * patch_size
    new_w = (w // patch_size) * patch_size
    if new_h != h or new_w != w:
        print(f"Cropping image from {h} x{w} to {new_h} x{new_w} to make divisible by patch size {patch_size} ")
        image = image[:new_h, :new_w]

    # 1. Visualize the Hilbert curve mapping
    print("\n 1. Visualizing Hilbert curve mapping...")
    fig_map = visualize_hilbert_curve(new_h, new_w, patch_size)

    # 2. Demonstrate the patching process
    print("\n 2. Demonstrating Hilbert curve patching...")
    fig_demo, fig_recon = demo_hilbert_patching(image, patch_size)

    # 3. Additional example: Process through a simulated transformer block
    print("\n 3. Simulating how patches would flow through a transformer...")

    # Convert to JAX array and add batch dimension
    jax_img = jnp.array(image)[None, ...]  # [1, H, W, C]

    # Get Hilbert curve patches and inverse indices
    patches, inv_idx = hilbert_patchify(jax_img, patch_size)

    print(f"Original image shape: {jax_img.shape} ")
    print(f"Patches shape: {patches.shape} ")

    # Simulate a transformer block that operates on the patch sequence
    def simulate_transformer_block(patches):
        """Apply a fixed near-identity linear map to every patch vector.

        Stands in for a transformer block for demonstration purposes: the
        last axis of ``patches`` is multiplied by (identity + small
        Gaussian noise), generated from a fixed PRNG seed so the output is
        reproducible.
        """
        patch_dim = patches.shape[-1]

        # Simulate learned weights (identity + small random values)
        key = jax.random.PRNGKey(42)
        weights = jnp.eye(patch_dim) + jax.random.normal(key, (patch_dim, patch_dim)) * 0.05

        # Apply "attention" (just a matrix multiply for demo)
        return jnp.matmul(patches, weights)

    # Process patches as if through a transformer
    processed_patches = simulate_transformer_block(patches)

    # Unpatchify back to image space
    h, w, c = jax_img.shape[1:]
    reconstructed = hilbert_unpatchify(processed_patches, inv_idx, patch_size, h, w, c)

    # Visualize the processed result
    fig_processed, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].imshow(np.array(jax_img[0]))
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    # The noisy weights can push values slightly outside [0, 1]; clip so
    # imshow receives a valid float image.
    ax[1].imshow(np.clip(np.array(reconstructed[0]), 0, 1))
    ax[1].set_title("After Simulated Transformer Processing")
    ax[1].axis('off')
    plt.tight_layout()

    # Save all figures
    print("\n Saving visualization figures...")
    fig_map.savefig("hilbert_curve_mapping.png")
    fig_demo.savefig("hilbert_patch_demo.png")
    fig_recon.savefig("hilbert_patch_reconstruction.png")
    fig_processed.savefig("hilbert_transformer_simulation.png")

    print("\n Done! Check the following output files:")
    print("- hilbert_curve_mapping.png - Visualizes how Hilbert curve maps through a grid")
    print("- hilbert_patch_demo.png - Shows patch ordering comparison")
    print("- hilbert_patch_reconstruction.png - Shows original vs reconstructed image")
    print("- hilbert_transformer_simulation.png - Shows a simple simulated transformer effect")

    # Display plots if running in interactive environment
    plt.show()
155+
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
0 commit comments