-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample_flow_matching.py
More file actions
executable file
·330 lines (276 loc) · 11.5 KB
/
sample_flow_matching.py
File metadata and controls
executable file
·330 lines (276 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
"""Sample from a flow matching checkpoint and save a visualization."""
import argparse
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from pathlib import Path
from txt2img_unsupervised.checkpoint import load_params
from txt2img_unsupervised.flow_matching import (
create_mollweide_projection_figure,
)
from txt2img_unsupervised.function_weighted_flow_model import (
BaseDistribution,
WeightingFunction,
sample_full_sphere,
sample_loop,
)
from txt2img_unsupervised.training_infra import setup_jax_for_training
def parse_arguments():
    """Parse and validate the command line arguments.

    The three cap-conditioning options (``--latitude``, ``--longitude`` and
    ``--cap-radius``) are only meaningful together. Supplying some but not all
    of them is reported as a usage error instead of being silently ignored.

    Returns:
        argparse.Namespace with attributes ``checkpoint_dir``, ``output``,
        ``step``, ``n_samples``, ``batch_size``, ``seed``, ``title``,
        ``latitude``, ``longitude``, ``cap_radius`` and ``n_steps``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when the arguments are
            invalid, including a partially specified cap.
    """
    parser = argparse.ArgumentParser(
        description="Sample from a flow matching model checkpoint and save visualization."
    )
    parser.add_argument(
        "checkpoint_dir", type=Path, help="Directory containing the checkpoint"
    )
    parser.add_argument(
        "--output", type=Path, default=Path("samples.png"), help="Output PNG file path"
    )
    parser.add_argument(
        "--step", type=int, help="Specific checkpoint step to load (default: latest)"
    )
    parser.add_argument(
        "--n-samples", type=int, default=1000, help="Number of samples to generate"
    )
    parser.add_argument(
        "--batch-size", type=int, default=2048, help="Batch size for sampling"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--title", type=str, default=None, help="Title for the visualization"
    )

    # Cap conditioning parameters
    parser.add_argument(
        "--latitude", type=float, help="Latitude in degrees for cap center"
    )
    parser.add_argument(
        "--longitude", type=float, help="Longitude in degrees for cap center"
    )
    parser.add_argument(
        "--cap-radius", type=float, help="Angular radius of the cap in degrees"
    )

    parser.add_argument(
        "--n-steps",
        type=int,
        default=8,
        help="Number of integration steps for sampling",
    )

    args = parser.parse_args()

    # Cap conditioning is all-or-nothing: a lone --latitude would otherwise be
    # dropped without warning and the run would sample unconditioned.
    cap_options = {
        "--latitude": args.latitude,
        "--longitude": args.longitude,
        "--cap-radius": args.cap_radius,
    }
    missing = [flag for flag, value in cap_options.items() if value is None]
    if missing and len(missing) != len(cap_options):
        parser.error(
            "cap conditioning requires --latitude, --longitude and --cap-radius "
            f"together; missing: {', '.join(missing)}"
        )

    return args
def latlon_to_unit_vector(lon, lat):
    """
    Map a (longitude, latitude) pair given in degrees onto the unit sphere.

    Args:
        lon: Longitude in degrees
        lat: Latitude in degrees

    Returns:
        NumPy array [x, y, z]; lon=0/lat=0 maps to +x and lat=90 maps to +z
    """
    lam = np.radians(lon)
    phi = np.radians(lat)
    # cos(latitude) is the length of the vector's projection onto the xy-plane.
    planar = np.cos(phi)
    return np.array([planar * np.cos(lam), planar * np.sin(lam), np.sin(phi)])
def create_mollweide_with_cap(samples, cap_center_latlon, cap_radius_deg, title=None):
    """
    Create a Mollweide projection visualization of 3D points with a cap boundary.

    Args:
        samples: Array of 3D unit vectors with shape [n_samples, 3]
        cap_center_latlon: Tuple of (latitude, longitude) in degrees for cap center
        cap_radius_deg: Angular radius of the cap in degrees
        title: Optional title for the figure

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If samples is not a 2D array with 3 columns
    """
    samples = np.asarray(samples)
    if samples.ndim != 2 or samples.shape[1] != 3:
        raise ValueError(f"Expected 3D samples, got shape {samples.shape}")

    fig = plt.figure(figsize=(16, 10), dpi=200)
    ax = fig.add_subplot(111, projection="mollweide")

    # Convert 3D coordinates to longitude/latitude
    # Mollweide projection expects longitude in [-pi, pi] and latitude in [-pi/2, pi/2]
    longitude = np.arctan2(samples[:, 1], samples[:, 0])  # atan2(y, x) for longitude
    # Rounding can leave z marginally outside [-1, 1], where arcsin yields NaN
    # and the point would silently vanish from the plot, so clip first.
    latitude = np.arcsin(np.clip(samples[:, 2], -1.0, 1.0))
    ax.scatter(longitude, latitude, s=8, alpha=0.7)

    # Draw cap boundary
    lat_center, lon_center = cap_center_latlon
    lon_center_rad = np.radians(lon_center)
    lat_center_rad = np.radians(lat_center)
    boundary_lons, boundary_lats = _cap_boundary_lonlat(
        lat_center_rad, lon_center_rad, np.radians(cap_radius_deg)
    )

    # For large cap radii, the boundary may cross the edge of the projection.
    # A longitude jump of more than pi between neighbours marks such a crossing;
    # split there so no line is drawn straight across the map. With no jumps,
    # np.split returns the whole boundary as a single segment.
    breaks = np.flatnonzero(np.abs(np.diff(boundary_lons)) > np.pi) + 1
    for segment_lons, segment_lats in zip(
        np.split(boundary_lons, breaks), np.split(boundary_lats, breaks)
    ):
        ax.plot(segment_lons, segment_lats, "r-", linewidth=2, alpha=0.7)

    # Plot the center of the cap
    ax.plot(lon_center_rad, lat_center_rad, "rx", markersize=10)

    ax.grid(True, alpha=0.3)

    tick_formatter = ticker.FuncFormatter(lambda x, pos: f"{np.degrees(x):.0f}°")

    # Set up longitude (x) ticks every 15 degrees and latitude (y) ticks every 10 degrees -
    # longitude ranges from -180 to +180 and latitude ranges from -90 to +90.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(np.radians(15)))
    ax.xaxis.set_major_formatter(tick_formatter)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(np.radians(10)))
    ax.yaxis.set_major_formatter(tick_formatter)

    if title is not None:
        ax.set_title(title)

    return fig


def _cap_boundary_lonlat(lat_center_rad, lon_center_rad, cap_radius_rad, n_points=100):
    """Return (longitudes, latitudes) in radians of points on the cap's boundary circle."""
    # One boundary point per azimuth (direction of travel away from the center).
    azimuth = np.linspace(0, 2 * np.pi, n_points)

    # Point at angular distance cap_radius_rad from the center in direction
    # azimuth, via the spherical law of sines and cosines.
    sin_lat = np.sin(lat_center_rad) * np.cos(cap_radius_rad) + np.cos(
        lat_center_rad
    ) * np.sin(cap_radius_rad) * np.cos(azimuth)
    lats = np.arcsin(np.clip(sin_lat, -1.0, 1.0))

    dlon = np.arctan2(
        np.sin(azimuth) * np.sin(cap_radius_rad) * np.cos(lat_center_rad),
        np.cos(cap_radius_rad) - np.sin(lat_center_rad) * np.sin(lats),
    )

    # Ensure longitude is within [-pi, pi] for Mollweide projection
    lons = ((lon_center_rad + dlon + np.pi) % (2 * np.pi)) - np.pi
    return lons, lats
def main():
    """Load a checkpoint, draw samples from the model, and save a Mollweide plot.

    When --latitude, --longitude and --cap-radius are all given, samples are
    conditioned on that cap and the cap boundary is drawn on the plot;
    otherwise samples are drawn from the full sphere.

    Raises:
        SystemExit: If the model's domain dimension is not 3. The message is
            written to stderr and the process exits with a non-zero status.
        ValueError: If cap conditioning is requested but the model's weighting
            function or the requested d_max does not support it.
    """
    args = parse_arguments()

    # Sets the correct RNG, critical so reference vectors are consistent with training. Also enables
    # compilation cache.
    setup_jax_for_training()

    print(f"Loading checkpoint from {args.checkpoint_dir}")
    params, step, mdl = load_params(args.checkpoint_dir, args.step)
    print(f"Using checkpoint step: {step}")
    print(mdl.tabulate(jax.random.PRNGKey(0), *mdl.dummy_inputs()))

    # Check if we have a 3D model (required for Mollweide projection).
    # Exit through SystemExit so the failure is visible in the exit status.
    if mdl.domain_dim != 3:
        raise SystemExit(
            f"Error: Model domain dimension is {mdl.domain_dim}, but must be 3 for Mollweide projection."
        )

    rng = jax.random.PRNGKey(args.seed)
    centers_rng, samples_rng = jax.random.split(rng)

    print(f"Generating {args.n_samples} samples...")

    use_cap_conditioning = all(
        param is not None for param in [args.latitude, args.longitude, args.cap_radius]
    )

    if use_cap_conditioning:
        samples = _sample_with_cap(mdl, params, samples_rng, args)
    else:
        # Unconditioned sampling - use d_max=2.0 (full sphere)
        print("Sampling from full sphere")
        samples = sample_full_sphere(
            mdl, params, centers_rng, args.n_samples, args.batch_size, args.n_steps
        )

    samples = jax.device_get(samples)

    title = args.title
    if use_cap_conditioning and title is None:
        title = f"Samples with cap center at lat={args.latitude}°, lon={args.longitude}°, radius={args.cap_radius}°"

    print("Creating Mollweide projection visualization...")
    if use_cap_conditioning:
        fig = create_mollweide_with_cap(
            samples,
            cap_center_latlon=(args.latitude, args.longitude),
            cap_radius_deg=args.cap_radius,
            title=title,
        )
    else:
        fig = create_mollweide_projection_figure(samples, title=title)

    print(f"Saving visualization to {args.output}")
    try:
        fig.savefig(args.output, bbox_inches="tight")
    finally:
        # Release the figure even when writing the file fails.
        plt.close(fig)

    print("Done!")


def _sample_with_cap(mdl, params, rng, args):
    """Draw args.n_samples samples conditioned on the cap described by args."""
    cap_center = latlon_to_unit_vector(args.longitude, args.latitude)
    max_cos_distance = 1 - np.cos(np.radians(args.cap_radius))

    print(
        f"Using cap conditioning with center at lat={args.latitude}°, lon={args.longitude}°"
    )
    print(
        f"Cap angular radius: {args.cap_radius}° (cosine distance: {max_cos_distance:.6f})"
    )

    if mdl.weighting_function not in (
        WeightingFunction.CAP_INDICATOR,
        WeightingFunction.SMOOTHED_CAP_INDICATOR,
    ):
        raise ValueError(
            f"Cap conditioning requires CAP_INDICATOR or SMOOTHED_CAP_INDICATOR "
            f"weighting function, but model has {mdl.weighting_function.value}"
        )

    # Validate d_max constraints for CAP base distribution
    if mdl.base_distribution == BaseDistribution.CAP and max_cos_distance > 1.0:
        if max_cos_distance < 2.0:
            raise ValueError(
                f"Invalid d_max {max_cos_distance:.3f}: for CAP models, d_max must be <= 1.0 or exactly 2.0, not between 1.0 and 2.0"
            )
        # max_cos_distance is 1 - cos(x), which cannot exceed 2, so reaching
        # this point means d_max == 2.0, i.e. the cap covers the whole sphere.
        print(
            "CAP base distribution with d_max = 2.0. Using hemisphere sampling strategy."
        )
        return sample_full_sphere(
            mdl,
            params,
            rng,
            args.n_samples,
            args.batch_size,
            args.n_steps,
        )

    # Direct sampling: either a non-CAP base distribution, or a CAP model with
    # d_max <= 1.0. Every sample shares the same cap center and d_max.
    cap_centers = jnp.tile(cap_center, (args.n_samples, 1))
    cap_d_maxes = jnp.full((args.n_samples,), max_cos_distance)
    weighting_function_params = (cap_centers, cap_d_maxes)
    return sample_loop(
        mdl,
        params,
        rng,
        weighting_function_params,
        args.n_samples,
        args.batch_size,
        args.n_steps,
    )
# Run the sampler only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()