Skip to content

Commit cb92c5c

Browse files
authored
Merge pull request #12 from PolymathicAI/tutorial
Added tutorial notebook
2 parents 4af233d + 973c932 commit cb92c5c

File tree

7 files changed

+1683
-0
lines changed

7 files changed

+1683
-0
lines changed

notebooks/tutorial/AstroCLIPTutorial.ipynb

Lines changed: 1465 additions & 0 deletions
Large diffs are not rendered by default.
1.42 MB
Loading
148 KB
Binary file not shown.
144 KB
Binary file not shown.
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from fillm.run.model import *
2+
import torch.nn.functional as F
3+
import torch
4+
import numpy as np
5+
import h5py
6+
from PIL import Image as im
7+
import matplotlib.pyplot as plt
8+
import matplotlib.patches as patches
9+
10+
from scipy.stats import binned_statistic_2d
11+
12+
def load_model_from_ckpt(ckpt_path: str):
    """
    Load a model and its stored configuration from a checkpoint.

    Parameters
    ----------
    ckpt_path : str
        Path to a checkpoint file, or to a directory that contains a
        ``ckpt.pt`` file.

    Returns
    -------
    dict
        ``{"model": model, "config": config}``: the model built from the
        checkpoint's ``config["model"]`` entry with the checkpoint's
        ``"model"`` state dict loaded into it, and the full ``"config"``
        entry of the checkpoint.
    """
    # Imported locally so the function does not rely on ``Path`` leaking in
    # through the star import at the top of the file.
    from pathlib import Path

    ckpt_path = Path(ckpt_path)
    if ckpt_path.is_dir():
        ckpt_path = ckpt_path / "ckpt.pt"

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be opened where no GPU is available.  ``load_state_dict`` copies
    # the values into the parameters the model already owns, so this does not
    # change which device the returned model lives on.
    chkpt = torch.load(ckpt_path, map_location="cpu")
    config = chkpt["config"]
    state_dict = chkpt["model"]

    # The model kind selects both the constructor and the subset of config
    # keys that the matching config class accepts.
    model_name = config["model"]["kind"]
    model_keys = get_model_keys(model_name)
    model_args = {k: config["model"][k] for k in model_keys}

    model_ctr, config_cls = model_registry[model_name]
    model_config = config_cls(**model_args)
    model_ = model_ctr(model_config)
    model_.load_state_dict(state_dict)

    return {"model": model_, "config": config}
33+
34+
def forward(
    self, x: torch.Tensor, y: Optional[torch.Tensor] = None
) -> dict:
    """
    Run the model and also return the output of every block.

    Written with an explicit ``self`` so it can be bound to a model instance.
    It reads ``self.config.block_size``, ``self.data_embed``,
    ``self.position_embed``, ``self.dropout``, ``self.blocks``,
    ``self.final_layernorm`` and ``self.head``.

    Parameters
    ----------
    x : torch.Tensor
        Input sequence; axis 1 is the sequence axis and its length must not
        exceed ``self.config.block_size``.
    y : torch.Tensor, optional
        Target sequence.  When given, positions where ``x`` differs from
        ``y`` are treated as the masked locations and the loss is computed
        on those positions only.

    Returns
    -------
    dict
        ``"preds"``: output of ``self.head``;
        ``"loss"``: mean squared error over the masked locations, or ``None``
        when ``y`` is not given (NaN if ``x`` and ``y`` are identical, since
        there is then nothing to average over);
        ``"embeddings"``: list with one detached copy of the activations
        after each block.

    Raises
    ------
    ValueError
        If the sequence is longer than ``self.config.block_size``.
    """
    seq_len = x.shape[1]
    if seq_len > self.config.block_size:
        raise ValueError(
            f"Cannot forward sequence of length {seq_len}, "
            f"block size is only {self.config.block_size}"
        )
    pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)  # shape (t)

    # forward the GPT model itself
    data_emb = self.data_embed(x)  # to shape (b, t, embedding_dim)
    pos_emb = self.position_embed(pos)  # to shape (t, embedding_dim)

    hidden = self.dropout(data_emb + pos_emb)
    embeddings = []
    for block in self.blocks:
        hidden = block(hidden)
        # Detached copies: callers get the activations without keeping the
        # autograd graph alive or aliasing tensors that later layers reuse.
        embeddings.append(hidden.detach().clone())
    hidden = self.final_layernorm(hidden)

    preds = self.head(hidden)

    loss = None
    if y is not None:
        # The mask is derived from the *input* x (left untouched above), and
        # only when targets are actually supplied.
        locs = (x != y).type_as(preds)
        # mse over all elements divided by the masked fraction equals the
        # mean squared error over the masked elements alone.
        loss = F.mse_loss(preds * locs, y * locs, reduction="mean") / locs.mean()

    return {"preds": preds, "loss": loss, "embeddings": embeddings}
70+
71+
def slice(x, section_length=10, overlap=5):
    """
    Cut a batch of sequences into overlapping, fixed-length windows.

    Windows start every ``section_length - overlap`` samples along axis 1.
    Each window is transposed over axes 1 and 2 and the windows are
    concatenated along axis 1, so a ``(batch, length, channels)`` input
    yields ``(batch, n_windows * channels, section_length)``.  A trailing
    window shorter than ``section_length`` is discarded.

    Parameters
    ----------
    x : torch.Tensor
        3-D tensor; axis 1 is the axis that is windowed.
    section_length : int
        Number of samples per window.
    overlap : int
        Number of samples shared by consecutive windows; must be smaller
        than ``section_length``.

    Returns
    -------
    torch.Tensor
        The full-length windows concatenated along axis 1.

    Raises
    ------
    ValueError
        If ``overlap >= section_length`` or the input is too short to
        contain a single full window.
    """
    step = section_length - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than section_length ({section_length})"
        )

    sections = [
        x[:, start:start + section_length].transpose(1, 2)
        for start in range(0, x.shape[1] - overlap, step)
    ]

    # After the transpose the window length sits on the LAST axis.  Axis 1 is
    # the channel axis, so checking it would drop a perfectly full final
    # window (few channels) or keep a short one and break the concatenation
    # (many channels).  Only the last window can ever be short.
    if sections and sections[-1].shape[-1] < section_length:
        sections.pop()

    if not sections:
        raise ValueError(
            f"input of length {x.shape[1]} is too short for a window of "
            f"length {section_length}"
        )

    return torch.cat(sections, 1)
81+
82+
83+
def fnc(x):
    """
    Standardize each sample, window it, and prepend its summary statistics.

    Steps, in order:

    1. subtract the mean and divide by the standard deviation, both taken
       along axis 1 (the standard deviation is floored at 0.2);
    2. cut the result into windows of length 20 that overlap by 10 using
       ``slice``;
    3. zero-pad with two extra leading entries on the last axis and one
       extra leading entry on the second-to-last axis;
    4. write the rescaled mean and standard deviation into the first two
       entries of that new leading row.

    NOTE(review): the squeeze-and-assign in step 4 presumably expects a
    single-channel input of shape (batch, length, 1) — confirm with callers.
    """
    mean = x.mean(1, keepdim=True)
    # Floor the spread so near-constant inputs are not blown up by the division.
    std = x.std(1, keepdim=True).clip_(min=0.2)

    normalized = (x - mean) / std
    windows = slice(normalized, section_length=20, overlap=10)

    # pad=(2, 0, 1, 0): two zeros in front of the last axis, one zero row in
    # front of the second-to-last axis.
    padded = F.pad(windows, pad=(2, 0, 1, 0), mode='constant', value=0)

    # The otherwise empty leading row carries the statistics removed in step 1.
    padded[:, 0, 0] = (mean.squeeze() - 2) / 2
    padded[:, 0, 1] = (std.squeeze() - 2) / 8

    return padded
92+
93+
def sdss_rgb(imgs, bands, scales=None, m=0.02):
    """
    Combine single-band images into an arcsinh-stretched RGB image.

    Parameters
    ----------
    imgs : sequence of 2-D arrays
        One image per entry of ``bands``.
    bands : sequence of str
        Band name of each image; each must be a key of the scale table
        (``'u'``, ``'g'``, ``'r'``, ``'i'``, ``'z'`` by default).
    scales : dict, optional
        ``{band: (plane, scale)}`` entries that override or extend the
        default table.  ``plane`` is the index of the output channel the
        band is written to, ``scale`` multiplies the image.
    m : float
        Offset added to every scaled image.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(H, W, 3)`` clipped to ``[0, 1]``.
        When several bands share a plane, the last one in ``bands`` wins.
    """
    band_settings = {
        'u': (2, 1.5),
        'g': (2, 2.5),
        'r': (1, 1.5),
        'i': (0, 1.0),
        'z': (0, 0.4),
    }
    if scales is not None:
        band_settings.update(scales)

    # Scale and offset every band once; remember the output plane it feeds.
    layers = []
    for img, band in zip(imgs, bands):
        plane, scale = band_settings[band]
        layers.append((plane, img * scale + m))

    # Mean intensity over bands, with negative pixels counted as zero.
    intensity = 0
    for _, layer in layers:
        intensity = intensity + np.maximum(0, layer)
    intensity /= len(bands)

    Q = 20
    stretched = np.arcsinh(Q * intensity) / np.sqrt(Q)
    # Nudge exact zeros so the division below is well defined.
    intensity += (intensity == 0.) * 1e-6

    height, width = intensity.shape
    rgb = np.zeros((height, width, 3), np.float32)
    for plane, layer in layers:
        # Rescale each band by the stretched-to-linear intensity ratio.
        rgb[:, :, plane] = layer * stretched / intensity

    return np.clip(rgb, 0, 1)
137+
138+
def dr2_rgb(rimgs, bands, **ignored):
    """
    Render images with ``sdss_rgb`` using a fixed g/r/z scale table.

    ``g`` is written to plane 2, ``r`` to plane 1 and ``z`` to plane 0, with
    an offset of 0.03.  Any extra keyword arguments are accepted and
    discarded.
    """
    grz_scales = {'g': (2, 6.0), 'r': (1, 3.4), 'z': (0, 2.2)}
    return sdss_rgb(rimgs, bands, scales=grz_scales, m=0.03)
140+
141+
# Code borrowed from https://github.com/georgestein/ssl-legacysurvey
142+
def scatter_plot_as_images(z_emb, images, nx=8, ny=8, npix_show=96, iseed=13579, display_image=True):
    """Tile a 2-D embedding with one image per occupied grid cell.

    The bounding box of the first two embedding components is divided into
    an ``nx`` x ``ny`` grid.  For every cell that contains at least one
    point, the image of the first such point (lowest index) is centre-cropped
    to ``npix_show`` pixels and pasted into the matching tile of a mosaic.

    Parameters
    ----------
    z_emb : array
        (N, D) array of locations in some compressed space.  Only the two
        leading components are used.
    images : sequence of arrays
        ``images[i]`` is the (H, W, 3) image that belongs to ``z_emb[i]``;
        H and W must be at least ``npix_show``.
    nx, ny : int
        Number of grid cells along the first and second component.
    npix_show : int
        Side length, in pixels, of each tile of the mosaic.
    iseed : int
        Unused.  Kept so existing calls keep working; the selection is
        deterministic and the global NumPy random state is left untouched.
    display_image : bool
        If True, draw the mosaic with matplotlib (first mosaic row at the
        bottom, axes hidden).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(ny * npix_show, nx * npix_show, 3)``.  Tiles
        of empty cells keep the fill value 255.
    """
    z_emb = np.asarray(z_emb)[:, :2]  # keep only first two dimensions

    binx = np.linspace(z_emb[:, 0].min(), z_emb[:, 0].max(), nx + 1)
    biny = np.linspace(z_emb[:, 1].min(), z_emb[:, 1].max(), ny + 1)

    ret = binned_statistic_2d(
        z_emb[:, 0], z_emb[:, 1], z_emb[:, 1], 'count',
        bins=[binx, biny], expand_binnumbers=True,
    )
    # scipy numbers the in-range bins 1..n (0 and n+1 are reserved for
    # out-of-range samples).  Shift to 0-based so the bins line up with the
    # tiles; without the shift the first row/column stays blank and the last
    # row/column is never drawn.
    z_emb_bins = ret.binnumber.T - 1

    img_full = np.zeros((ny * npix_show, nx * npix_show, 3)) + 255

    # np.unique reports the first occurrence of every distinct cell id, i.e.
    # the lowest-index sample that falls into each occupied cell.
    cell_ids = z_emb_bins[:, 0] * ny + z_emb_bins[:, 1]
    _, first_inds = np.unique(cell_ids, return_index=True)

    # Add each image as postage stamp in desired region
    for ind in first_inds:
        ix, iy = z_emb_bins[ind]
        img_full[
            iy * npix_show:(iy + 1) * npix_show,
            ix * npix_show:(ix + 1) * npix_show,
        ] = _center_crop(images[ind], npix_show)

    if display_image:
        plt.figure(figsize=(nx, ny))
        plt.imshow(img_full, origin='lower')
        plt.axis('off')

    return img_full


def _center_crop(img, size):
    """Return the central ``size`` x ``size`` window of an (H, W, ...) image."""
    height, width = img.shape[:2]
    if height < size or width < size:
        raise ValueError(
            f"image of size {height}x{width} is smaller than the requested "
            f"{size}x{size} tile"
        )
    top = (height - size) // 2
    left = (width - size) // 2
    return img[top:top + size, left:left + size]
155 KB
Binary file not shown.
156 KB
Binary file not shown.

0 commit comments

Comments
 (0)