1+ from fillm .run .model import *
2+ import torch .nn .functional as F
3+ import torch
4+ import numpy as np
5+ import h5py
6+ from PIL import Image as im
7+ import matplotlib .pyplot as plt
8+ import matplotlib .patches as patches
9+
10+ from scipy .stats import binned_statistic_2d
11+
def load_model_from_ckpt(ckpt_path: str, map_location=None):
    """
    Load a model and its config from a checkpoint.

    The checkpoint is read with ``torch.load`` and must be a mapping with a
    ``"config"`` entry and a ``"model"`` entry (the state dict).
    ``config["model"]["kind"]`` selects the constructor/config-class pair in
    ``model_registry``; the constructor arguments are the ``config["model"]``
    values for the keys returned by ``get_model_keys``.

    Parameters
    ----------
    ckpt_path : str or path-like
        Path to the checkpoint file, or to a directory containing ``ckpt.pt``.
    map_location : optional
        Forwarded unchanged to ``torch.load``. Pass e.g. ``"cpu"`` to open a
        checkpoint whose tensors were saved on a device that is not available
        on this machine. ``None`` (default) keeps ``torch.load``'s default
        device handling.

    Returns
    -------
    dict
        ``{"model": <model with the state dict loaded>, "config": <config>}``.
    """
    # Imported locally so the function does not depend on `Path` leaking in
    # through the module's star-import.
    from pathlib import Path

    ckpt_path = Path(ckpt_path)
    if ckpt_path.is_dir():
        ckpt_path = ckpt_path / "ckpt.pt"

    # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
    chkpt = torch.load(ckpt_path, map_location=map_location)
    config = chkpt["config"]
    state_dict = chkpt["model"]

    model_name = config["model"]['kind']
    model_keys = get_model_keys(model_name)
    model_args = {k: config['model'][k] for k in model_keys}

    model_ctr, config_cls = model_registry[model_name]
    model_ = model_ctr(config_cls(**model_args))
    model_.load_state_dict(state_dict)

    return {"model": model_, "config": config}
33+
def forward(
    self, x: torch.Tensor, y: Optional[torch.Tensor] = None
) -> dict:
    """
    Run the model on ``x`` and, if targets are given, compute a masked MSE.

    Parameters
    ----------
    x : torch.Tensor
        Input sequence; ``x.shape[1]`` is the sequence length and must not
        exceed ``self.config.block_size``.
    y : torch.Tensor, optional
        Targets, compared element-wise with ``x``. Positions where ``x != y``
        are treated as masked and are the only ones contributing to the loss.

    Returns
    -------
    dict
        ``"preds"``: output of ``self.head``;
        ``"loss"``: masked MSE (``None`` when ``y`` is ``None``);
        ``"embeddings"``: list with a detached copy of the output of every
        block, in order.

    Raises
    ------
    ValueError
        If the sequence is longer than ``self.config.block_size``.
    """
    seq_len = x.shape[1]
    if seq_len > self.config.block_size:
        raise ValueError(
            f"Cannot forward sequence of length {seq_len}, "
            f"block size is only {self.config.block_size}"
        )

    # `x` is rebound to hidden states below; keep the raw inputs so the loss
    # mask can still compare them against the targets.
    inputs = x
    pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)  # shape (t)

    # forward the GPT model itself
    data_emb = self.data_embed(x)  # to shape (b, t, embedding_dim)
    pos_emb = self.position_embed(pos)  # to shape (t, embedding_dim)

    x = self.dropout(data_emb + pos_emb)
    embeddings = []
    for block in self.blocks:
        x = block(x)
        # detached copies: exposing them must not affect the autograd graph
        embeddings.append(x.detach().clone())
    x = self.final_layernorm(x)

    preds = self.head(x)

    loss = None
    if y is not None:
        # mask locations are the positions where input and target differ
        locs = (inputs != y).type_as(preds)
        masked_mse = F.mse_loss(preds * locs, y * locs, reduction="mean")
        # Dividing by the masked fraction turns the mean over all positions
        # into a mean over masked positions. Clamp so that a batch without
        # any masked position yields a loss of 0 instead of 0/0 = NaN; any
        # non-zero fraction is far above `tiny` and is left untouched.
        masked_frac = locs.mean().clamp_min(torch.finfo(locs.dtype).tiny)
        loss = masked_mse / masked_frac

    return {"preds": preds, "loss": loss, "embeddings": embeddings}
70+
def slice(x, section_length=10, overlap=5):
    """
    Cut ``x`` into overlapping windows along axis 1 and stack them on axis 1.

    Windows start every ``section_length - overlap`` steps. Each window
    ``x[:, start:start + section_length]`` has its last two axes swapped, so
    ``x`` must be 3-D; the window length then sits on the last axis. A trailing
    window that runs past the end of ``x`` (and is therefore shorter than
    ``section_length``) is discarded.

    Parameters
    ----------
    x : torch.Tensor
        3-D tensor; axis 1 is the axis being windowed.
    section_length : int
        Length of every window.
    overlap : int
        Number of steps shared by consecutive windows.

    Returns
    -------
    torch.Tensor
        The transposed windows concatenated along axis 1.
    """
    stride = section_length - overlap
    start_indices = np.arange(0, x.shape[1] - overlap, stride)
    sections = [
        x[:, start:start + section_length].transpose(1, 2)
        for start in start_indices
    ]

    # After the transpose the window length is the LAST axis (axis 1 now holds
    # what used to be axis 2), so the completeness check must look there.
    # Only the final window can be short.
    if sections[-1].shape[-1] < section_length:
        sections.pop()  # Discard the incomplete last section

    return torch.cat(sections, 1)
81+
82+
def fnc(x):
    """
    Standardise ``x`` along axis 1, window it, and prepend the statistics.

    Steps:
      1. subtract the axis-1 mean and divide by the axis-1 standard deviation
         (the latter floored at 0.2);
      2. cut the result into windows with ``slice(x, 20, 10)``;
      3. zero-pad one row at the front of the second-to-last axis and two
         columns at the front of the last axis;
      4. store the rescaled mean, ``(mean - 2) / 2``, and the rescaled std,
         ``(std - 2) / 8``, in the first two entries of the padded row.

    NOTE(review): the two assignments in step 4 only fit if the squeezed
    statistics hold one value per batch element (or a single value) —
    presumably ``x`` is ``(batch, time, 1)``; confirm against the caller.
    """
    mean = x.mean(1, keepdim=True)
    std = torch.clamp(x.std(1, keepdim=True), min=0.2)

    normalised = (x - mean) / std
    windows = slice(normalised, 20, 10)

    # pad = (last-axis left, last-axis right, prev-axis front, prev-axis back)
    padded = F.pad(windows, pad=(2, 0, 1, 0), mode='constant', value=0)

    # The padded row carries the statistics removed by the normalisation.
    padded[:, 0, 0] = (mean.squeeze() - 2) / 2
    padded[:, 0, 1] = (std.squeeze() - 2) / 8

    return padded
92+
def sdss_rgb(imgs, bands, scales=None,
             m=0.02):
    """
    Combine single-band images into an arcsinh-stretched RGB image.

    Parameters
    ----------
    imgs : sequence of 2-D arrays
        One image per entry of ``bands``, all of the same ``(H, W)`` shape.
    bands : sequence of str
        Band name of each image; each must be a key of the band table
        (``'u'``, ``'g'``, ``'r'``, ``'i'``, ``'z'`` by default).
    scales : dict, optional
        ``{band: (plane, scale)}`` entries added to / overriding the defaults.
        ``plane`` is the index of the output channel the band is written to,
        ``scale`` the factor applied to the band image.
    m : float
        Offset added to every scaled image.

    Returns
    -------
    numpy.ndarray
        ``(H, W, 3)`` float32 array clipped to ``[0, 1]``. If several bands map
        to the same plane, the one appearing last in ``bands`` wins.
    """
    band_table = {
        'u': (2, 1.5),
        'g': (2, 2.5),
        'r': (1, 1.5),
        'i': (0, 1.0),
        'z': (0, 0.4),
    }
    if scales is not None:
        band_table.update(scales)

    # Mean intensity over the bands; negative values are floored at zero here
    # only, the per-plane values below are used unfloored.
    intensity = sum(
        np.maximum(0, image * band_table[band][1] + m)
        for image, band in zip(imgs, bands)
    )
    intensity /= len(bands)

    softening = 20
    stretched = np.arcsinh(softening * intensity) / np.sqrt(softening)

    # Avoid dividing by zero below (done after the stretch was computed).
    intensity += (intensity == 0.) * 1e-6

    height, width = intensity.shape
    rgb = np.zeros((height, width, 3), np.float32)
    for image, band in zip(imgs, bands):
        channel, gain = band_table[band]
        rgb[:, :, channel] = (image * gain + m) * stretched / intensity

    return np.clip(rgb, 0, 1)
137+
def dr2_rgb(rimgs, bands, **ignored):
    """
    Build an RGB image with ``sdss_rgb`` using fixed g/r/z settings.

    The ``g``, ``r`` and ``z`` bands are written to planes 2, 1 and 0 with
    scales 6.0, 3.4 and 2.2, and the offset ``m`` is 0.03. Extra keyword
    arguments are accepted and ignored.
    """
    dr2_scales = {
        'g': (2, 6.0),
        'r': (1, 3.4),
        'z': (0, 2.2),
    }
    return sdss_rgb(rimgs, bands, scales=dr2_scales, m=0.03)
140+
141+ # Code borrowed from https://github.com/georgestein/ssl-legacysurvey
def scatter_plot_as_images(z_emb, images, nx=8, ny=8, npix_show=96, iseed=13579, display_image=True):
    """Sample points from scatter plot and display as their original galaxy image

    The bounding box of the embedding is divided into an ``nx`` x ``ny`` grid.
    For every non-empty grid cell one member is drawn at random (reproducibly,
    from ``iseed``) and the central ``npix_show`` x ``npix_show`` pixels of its
    image are pasted at the cell's position in a mosaic.

    Parameters
    ----------
    z_emb: array
        (N, 2) array of the galaxies location in some compressed space.
        If second axis has a dimensionality greater than 2 we only consider the leading two components.
    images: indexable
        ``images[i]`` is the image of point ``i``; it must convert to an
        ``(H, W, 3)`` array (or one broadcastable to it after cropping) with
        ``H, W >= npix_show``.
    nx, ny: int
        Number of grid cells along the first / second embedding component.
    npix_show: int
        Side length, in pixels, of every postage stamp.
    iseed: int
        Base seed for picking the displayed member of each cell.
    display_image: bool
        If True, show the mosaic with matplotlib.

    Returns
    -------
    numpy.ndarray
        ``(ny * npix_show, nx * npix_show, 3)`` float array. Cells without any
        point keep the background value 255.

    Raises
    ------
    ValueError
        If a selected image is smaller than ``npix_show`` along either axis.
    """
    z_emb = z_emb[:, :2]  # keep only first two dimensions

    img_full = np.zeros((ny * npix_show, nx * npix_show, 3)) + 255

    binx = np.linspace(z_emb[:, 0].min(), z_emb[:, 0].max(), nx + 1)
    biny = np.linspace(z_emb[:, 1].min(), z_emb[:, 1].max(), ny + 1)

    ret = binned_statistic_2d(z_emb[:, 0], z_emb[:, 1], z_emb[:, 1], 'count',
                              bins=[binx, biny], expand_binnumbers=True)
    # Expanded bin numbers are 1-based (bin i spans edges[i-1]..edges[i]);
    # shift them so they line up with the 0-based grid indices used below.
    cell_x, cell_y = ret.binnumber - 1

    inds_lin = np.arange(z_emb.shape[0])

    # Add each image as postage stamp in desired region
    for ix in range(nx):
        for iy in range(ny):
            inds = inds_lin[(cell_x == ix) & (cell_y == iy)]
            if len(inds) == 0:
                continue

            # Private generator: same draw as seeding the global NumPy RNG
            # with this value, without clobbering the caller's random state.
            rng = np.random.RandomState(ix * nx + iy + iseed)
            ind_plt = rng.choice(inds)

            img = np.asarray(images[ind_plt])
            top = (img.shape[0] - npix_show) // 2
            left = (img.shape[1] - npix_show) // 2
            if top < 0 or left < 0:
                raise ValueError(
                    f"image {ind_plt} has shape {img.shape[:2]}, "
                    f"smaller than npix_show={npix_show}"
                )
            # Central crop; for 152-pixel images and npix_show=96 this is
            # the [28:-28, 28:-28] window.
            stamp = img[top:top + npix_show, left:left + npix_show]
            img_full[iy * npix_show:(iy + 1) * npix_show,
                     ix * npix_show:(ix + 1) * npix_show] = stamp

    if display_image:
        plt.figure(figsize=(nx, ny))
        plt.imshow(img_full, origin='lower')
        plt.axis('off')

    return img_full
0 commit comments