1- from typing import List , Optional , Tuple , Union
1+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2+ # Copyright 2024-2025 The HuggingFace Team. All rights reserved.
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+ # --------------------------------------------------------------------------
16+ # More information and citation instructions are available on the
17+ # Marigold project website: https://marigoldcomputervision.github.io
18+ # --------------------------------------------------------------------------
19+ from typing import Any , Dict , List , Optional , Tuple , Union
220
321import numpy as np
422import PIL
@@ -379,7 +397,7 @@ def visualize_depth(
379397 val_min : float = 0.0 ,
380398 val_max : float = 1.0 ,
381399 color_map : str = "Spectral" ,
382- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
400+ ) -> List [PIL .Image .Image ]:
383401 """
384402 Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.
385403
@@ -391,7 +409,7 @@ def visualize_depth(
391409 color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
392410 depth prediction into colored representation.
393411
394- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with depth maps visualization.
412+ Returns: `List[PIL.Image.Image]` with depth maps visualization.
395413 """
396414 if val_max <= val_min :
397415 raise ValueError (f"Invalid values range: [{ val_min } , { val_max } ]." )
@@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
436454 depth : Union [np .ndarray , torch .Tensor , List [np .ndarray ], List [torch .Tensor ]],
437455 val_min : float = 0.0 ,
438456 val_max : float = 1.0 ,
439- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
457+ ) -> List [PIL .Image .Image ]:
440458 def export_depth_to_16bit_png_one (img , idx = None ):
441459 prefix = "Depth" + (f"[{ idx } ]" if idx else "" )
442460 if not isinstance (img , np .ndarray ) and not torch .is_tensor (img ):
@@ -478,7 +496,7 @@ def visualize_normals(
478496 flip_x : bool = False ,
479497 flip_y : bool = False ,
480498 flip_z : bool = False ,
481- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
499+ ) -> List [PIL .Image .Image ]:
482500 """
483501 Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.
484502
@@ -492,7 +510,7 @@ def visualize_normals(
492510 flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
493511 Default direction is facing the observer.
494512
495- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with surface normals visualization.
513+ Returns: `List[PIL.Image.Image]` with surface normals visualization.
496514 """
497515 flip_vec = None
498516 if any ((flip_x , flip_y , flip_z )):
@@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
528546 else :
529547 raise ValueError (f"Unexpected input type: { type (normals )} " )
530548
@staticmethod
def visualize_intrinsics(
    prediction: Union[
        np.ndarray,
        torch.Tensor,
        List[np.ndarray],
        List[torch.Tensor],
    ],
    target_properties: Dict[str, Any],
    color_map: Union[str, Dict[str, str]] = "binary",
) -> List[Dict[str, PIL.Image.Image]]:
    """
    Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.

    Args:
        prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
            Intrinsic image decomposition. A single array or tensor must end up with shape `[N*T,3,H,W]`, where
            `T` is the number of entries in `target_names`; it is split into `N` groups of `T` images. When a
            list is passed, every element is treated as one group of per-target images of shape `[T,3,H,W]`.
        target_properties (`Dict[str, Any]`):
            Decomposition properties. Expected entries: `target_names: List[str]`, and, for each target and
            sub-target name, a dictionary with the keys `prediction_space: str` (treated as `"srgb"` when
            absent), `sub_target_names: List[Union[str, None]]` (read only when `prediction_space` is
            `"stack"`; must have 3 entries, `None` for missing modalities), and `up_to_scale: bool` (treated as
            `False` when absent).
        color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"binary"`):
            Color map used to convert single-channel predictions into colored representations. When a
            dictionary is passed, each modality can be colored with its own color map; modalities missing from
            the dictionary fall back to `"binary"`.

    Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization. Each dictionary
        maps a target name to its image; for `"stack"` targets it additionally maps every non-`None` sub-target
        name to its colorized channel.

    Raises:
        ValueError: if `target_names` is missing, `color_map` is neither a string nor a dictionary of strings,
            `prediction` (or any list element) is `None`, the array/tensor shape is not `[N*T,3,H,W]`, a
            `"stack"` target has malformed `sub_target_names`, or `prediction` has an unsupported type.
    """
    if "target_names" not in target_properties:
        raise ValueError("Missing `target_names` in target_properties")
    if not isinstance(color_map, str) and not (
        isinstance(color_map, dict)
        and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
    ):
        raise ValueError("`color_map` must be a string or a dictionary of strings")
    target_names = target_properties["target_names"]
    n_targets = len(target_names)

    def linear_to_display(img, props):
        # Optionally normalize by the maximum (scale-ambiguous predictions), then apply a 1/2.2 gamma.
        if props.get("up_to_scale", False):
            img = img / max(img.max().item(), 1e-6)
        return img ** (1 / 2.2)

    def visualize_targets_one(images, idx=None):
        # images: [T, 3, H, W]; `idx` is accepted for parity with the call sites but is not used here.
        out = {}
        for target_name, img in zip(target_names, images):
            img = img.permute(1, 2, 0)  # [H, W, 3]
            props = target_properties[target_name]
            prediction_space = props.get("prediction_space", "srgb")
            if prediction_space == "stack":
                sub_target_names = props["sub_target_names"]
                if len(sub_target_names) != 3 or any(
                    not (isinstance(s, str) or s is None) for s in sub_target_names
                ):
                    raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
                # Every named channel of the stack is colorized and stored under its own key.
                for i, sub_target_name in enumerate(sub_target_names):
                    if sub_target_name is None:
                        continue
                    sub_img = img[:, :, i]
                    sub_props = target_properties[sub_target_name]
                    if sub_props.get("prediction_space", "srgb") == "linear":
                        sub_img = linear_to_display(sub_img, sub_props)
                    cmap_name = (
                        color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
                    )
                    sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
                    out[sub_target_name] = PIL.Image.fromarray(sub_img.cpu().numpy())
            elif prediction_space == "linear":
                img = linear_to_display(img, props)
            # "srgb" needs no conversion; any other value is likewise passed through unchanged.
            # The raw 3-channel image is stored for every target, including "stack" ones.
            # NOTE(review): values are not clamped to [0, 1] before the uint8 cast - confirm that upstream
            # predictions are already within range.
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            out[target_name] = PIL.Image.fromarray(img)
        return out

    def as_tensor(item):
        # NumPy arrays have no `.permute`, so list elements given as arrays are converted with the same helpers
        # as a single array input (presumably channel-last -> channel-first; confirm against `numpy_to_pt`).
        if isinstance(item, np.ndarray):
            item = MarigoldImageProcessor.expand_tensor_or_array(item)
            item = MarigoldImageProcessor.numpy_to_pt(item)
        return item

    if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction):
        raise ValueError("Input prediction is `None`")
    if isinstance(prediction, (np.ndarray, torch.Tensor)):
        prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
        if isinstance(prediction, np.ndarray):
            prediction = MarigoldImageProcessor.numpy_to_pt(prediction)  # [N*T,3,H,W]
        if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
            raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
        N_T, _, H, W = prediction.shape
        N = N_T // n_targets
        # Group consecutive images so that each group holds one image per target.
        prediction = prediction.reshape(N, n_targets, 3, H, W)
        return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
    elif isinstance(prediction, list):
        return [visualize_targets_one(as_tensor(img), idx) for idx, img in enumerate(prediction)]
    else:
        raise ValueError(f"Unexpected input type: {type(prediction)}")
641+
531642 @staticmethod
532643 def visualize_uncertainty (
533644 uncertainty : Union [
@@ -537,24 +648,26 @@ def visualize_uncertainty(
537648 List [torch .Tensor ],
538649 ],
539650 saturation_percentile = 95 ,
540- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
651+ ) -> List [PIL .Image .Image ]:
541652 """
542- Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.
653+ Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
654+ `MarigoldIntrinsicsPipeline`.
543655
544656 Args:
545657 uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
546658 Uncertainty maps.
547659 saturation_percentile (`int`, *optional*, defaults to `95`):
548660 Specifies the percentile uncertainty value visualized with maximum intensity.
549661
550- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with uncertainty visualization.
662+ Returns: `List[PIL.Image.Image]` with uncertainty visualization.
551663 """
552664
553665 def visualize_uncertainty_one (img , idx = None ):
554666 prefix = "Uncertainty" + (f"[{ idx } ]" if idx else "" )
555667 if img .min () < 0 :
556- raise ValueError (f"{ prefix } : unexected data range, min={ img .min ()} ." )
557- img = img .squeeze (0 ).cpu ().numpy ()
668+ raise ValueError (f"{ prefix } : unexpected data range, min={ img .min ()} ." )
669+ img = img .permute (1 , 2 , 0 ) # [H,W,C]
670+ img = img .squeeze (2 ).cpu ().numpy () # [H,W] or [H,W,3]
558671 saturation_value = np .percentile (img , saturation_percentile )
559672 img = np .clip (img * 255 / saturation_value , 0 , 255 )
560673 img = img .astype (np .uint8 )
@@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
566679 if isinstance (uncertainty , (np .ndarray , torch .Tensor )):
567680 uncertainty = MarigoldImageProcessor .expand_tensor_or_array (uncertainty )
568681 if isinstance (uncertainty , np .ndarray ):
569- uncertainty = MarigoldImageProcessor .numpy_to_pt (uncertainty ) # [N,1 ,H,W]
570- if not (uncertainty .ndim == 4 and uncertainty .shape [1 ] == 1 ):
571- raise ValueError (f"Unexpected input shape={ uncertainty .shape } , expecting [N,1 ,H,W]." )
682+ uncertainty = MarigoldImageProcessor .numpy_to_pt (uncertainty ) # [N,C ,H,W]
683+ if not (uncertainty .ndim == 4 and uncertainty .shape [1 ] in ( 1 , 3 ) ):
684+ raise ValueError (f"Unexpected input shape={ uncertainty .shape } , expecting [N,C ,H,W] with C in (1,3) ." )
572685 return [visualize_uncertainty_one (img , idx ) for idx , img in enumerate (uncertainty )]
573686 elif isinstance (uncertainty , list ):
574687 return [visualize_uncertainty_one (img , idx ) for idx , img in enumerate (uncertainty )]
0 commit comments