1717# PCA visualization for the image embeddings
1818#
1919
20- def compute_pca (embeddings : np .ndarray ) -> np .ndarray :
20+ def compute_pca (embeddings : np .ndarray , n_components : int = 3 , as_rgb : bool = True ) -> np .ndarray :
2121 """Compute the pca projection of the embeddings to visualize them as RGB image.
2222
2323 Args:
2424 embeddings: The embeddings. For example predicted by the SAM image encoder.
25+ n_components: The number of PCA components to use for dimensionality reduction.
26+ as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.
2527
2628 Returns:
2729 PCA of the embeddings, mapped to the pixels.
2830 """
2931 if embeddings .ndim == 4 :
30- pca = embedding_pca (embeddings .squeeze ()).transpose ((1 , 2 , 0 ))
32+ pca = embedding_pca (embeddings .squeeze (), n_components = n_components , as_rgb = as_rgb ).transpose ((1 , 2 , 0 ))
3133 elif embeddings .ndim == 5 :
3234 pca = []
3335 for embed in embeddings :
34- vis = embedding_pca (embed .squeeze ()).transpose ((1 , 2 , 0 ))
36+ vis = embedding_pca (embed .squeeze (), n_components = n_components , as_rgb = as_rgb ).transpose ((1 , 2 , 0 ))
3537 pca .append (vis )
3638 pca = np .stack (pca )
3739 else :
@@ -53,10 +55,10 @@ def _get_crop(embed_shape, shape):
5355 return crop
5456
5557
56- def _project_embeddings (embeddings , shape , apply_crop = True ):
58+ def _project_embeddings (embeddings , shape , apply_crop = True , n_components = 3 , as_rgb = True ):
5759 assert embeddings .ndim == len (shape ) + 2 , f"{ embeddings .shape } , { shape } "
5860
59- embedding_vis = compute_pca (embeddings )
61+ embedding_vis = compute_pca (embeddings , n_components = n_components , as_rgb = as_rgb )
6062 if not apply_crop :
6163 pass
6264 elif len (shape ) == 2 :
@@ -107,7 +109,7 @@ def resize_shape(shape):
107109 return np .concatenate ([resize (arr , resize_shape (arr .shape )) for arr in arrays ], axis = axis )
108110
109111
110- def _project_tiled_embeddings (image_embeddings ):
112+ def _project_tiled_embeddings (image_embeddings , n_components , as_rgb ):
111113 features = image_embeddings ["features" ]
112114 tile_shape , halo , shape = features .attrs ["tile_shape" ], features .attrs ["halo" ], features .attrs ["shape" ]
113115 tiling = blocking ([0 , 0 ], shape , tile_shape )
@@ -141,30 +143,34 @@ def _project_tiled_embeddings(image_embeddings):
141143
142144 if features ["0" ].ndim == 5 :
143145 shape = (features ["0" ].shape [0 ],) + tuple (shape )
144- embedding_vis , scale = _project_embeddings (embeds , shape , apply_crop = False )
146+ embedding_vis , scale = _project_embeddings (
147+ embeds , shape , n_components = n_components , as_rgb = as_rgb , apply_crop = False
148+ )
145149 return embedding_vis , scale
146150
147151
148152def project_embeddings_for_visualization (
149- image_embeddings : ImageEmbeddings
153+ image_embeddings : ImageEmbeddings , n_components : int = 3 , as_rgb : bool = True ,
150154) -> Tuple [np .ndarray , Tuple [float , ...]]:
151155 """Project image embeddings to pixel-wise PCA.
152156
153157 Args:
154158 image_embeddings: The image embeddings.
159+ n_components: The number of PCA components to use for dimensionality reduction.
160+ as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.
155161
156162 Returns:
157163 The PCA of the embeddings.
158164 The scale factor for resizing to the original image size.
159165 """
160166 is_tiled = image_embeddings ["input_size" ] is None
161167 if is_tiled :
162- embedding_vis , scale = _project_tiled_embeddings (image_embeddings )
168+ embedding_vis , scale = _project_tiled_embeddings (image_embeddings , n_components , as_rgb )
163169 else :
164170 embeddings = image_embeddings ["features" ]
165171 shape = tuple (image_embeddings ["original_size" ])
166172 if embeddings .ndim == 5 :
167173 shape = (embeddings .shape [0 ],) + shape
168- embedding_vis , scale = _project_embeddings (embeddings , shape )
174+ embedding_vis , scale = _project_embeddings (embeddings , shape , n_components = n_components , as_rgb = as_rgb )
169175
170176 return embedding_vis , scale
0 commit comments