1- from  typing  import  List , Optional , Tuple , Union 
1+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. 
2+ # Copyright 2024-2025 The HuggingFace Team. All rights reserved. 
3+ # 
4+ # Licensed under the Apache License, Version 2.0 (the "License"); 
5+ # you may not use this file except in compliance with the License. 
6+ # You may obtain a copy of the License at 
7+ # 
8+ #     http://www.apache.org/licenses/LICENSE-2.0 
9+ # 
10+ # Unless required by applicable law or agreed to in writing, software 
11+ # distributed under the License is distributed on an "AS IS" BASIS, 
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13+ # See the License for the specific language governing permissions and 
14+ # limitations under the License. 
15+ # -------------------------------------------------------------------------- 
16+ # More information and citation instructions are available on the 
17+ # Marigold project website: https://marigoldcomputervision.github.io 
18+ # -------------------------------------------------------------------------- 
19+ from  typing  import  Any , Dict , List , Optional , Tuple , Union 
220
321import  numpy  as  np 
422import  PIL 
@@ -379,7 +397,7 @@ def visualize_depth(
379397        val_min : float  =  0.0 ,
380398        val_max : float  =  1.0 ,
381399        color_map : str  =  "Spectral" ,
382-     ) ->  Union [ PIL . Image . Image ,  List [PIL .Image .Image ] ]:
400+     ) ->  List [PIL .Image .Image ]:
383401        """ 
384402        Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. 
385403
@@ -391,7 +409,7 @@ def visualize_depth(
391409            color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel 
392410                      depth prediction into colored representation. 
393411
394-         Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with depth maps visualization. 
412+         Returns: `List[PIL.Image.Image]` with depth maps visualization. 
395413        """ 
396414        if  val_max  <=  val_min :
397415            raise  ValueError (f"Invalid values range: [{ val_min } { val_max }  )
@@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
436454        depth : Union [np .ndarray , torch .Tensor , List [np .ndarray ], List [torch .Tensor ]],
437455        val_min : float  =  0.0 ,
438456        val_max : float  =  1.0 ,
439-     ) ->  Union [ PIL . Image . Image ,  List [PIL .Image .Image ] ]:
457+     ) ->  List [PIL .Image .Image ]:
440458        def  export_depth_to_16bit_png_one (img , idx = None ):
441459            prefix  =  "Depth"  +  (f"[{ idx }   if  idx  else  "" )
442460            if  not  isinstance (img , np .ndarray ) and  not  torch .is_tensor (img ):
@@ -478,7 +496,7 @@ def visualize_normals(
478496        flip_x : bool  =  False ,
479497        flip_y : bool  =  False ,
480498        flip_z : bool  =  False ,
481-     ) ->  Union [ PIL . Image . Image ,  List [PIL .Image .Image ] ]:
499+     ) ->  List [PIL .Image .Image ]:
482500        """ 
483501        Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. 
484502
@@ -492,7 +510,7 @@ def visualize_normals(
492510            flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. 
493511                      Default direction is facing the observer. 
494512
495-         Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with surface normals visualization. 
513+         Returns: `List[PIL.Image.Image]` with surface normals visualization. 
496514        """ 
497515        flip_vec  =  None 
498516        if  any ((flip_x , flip_y , flip_z )):
@@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
528546        else :
529547            raise  ValueError (f"Unexpected input type: { type (normals )}  )
530548
549+     @staticmethod  
550+     def  visualize_intrinsics (
551+         prediction : Union [
552+             np .ndarray ,
553+             torch .Tensor ,
554+             List [np .ndarray ],
555+             List [torch .Tensor ],
556+         ],
557+         target_properties : Dict [str , Any ],
558+         color_map : Union [str , Dict [str , str ]] =  "binary" ,
559+     ) ->  List [Dict [str , PIL .Image .Image ]]:
560+         """ 
561+         Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`. 
562+ 
563+         Args: 
564+             prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): 
565+                 Intrinsic image decomposition. 
566+             target_properties (`Dict[str, Any]`): 
567+                 Decomposition properties. Expected entries: `target_names: List[str]` and a dictionary with keys 
568+                 `prediction_space: str`, `sub_target_names: List[Union[str, Null]]` (must have 3 entries, null for 
569+                 missing modalities), `up_to_scale: bool`, one for each target and sub-target. 
570+             color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"Spectral"`): 
571+                 Color map used to convert a single-channel predictions into colored representations. When a dictionary 
572+                 is passed, each modality can be colored with its own color map. 
573+ 
574+         Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization. 
575+         """ 
576+         if  "target_names"  not  in target_properties :
577+             raise  ValueError ("Missing `target_names` in target_properties" )
578+         if  not  isinstance (color_map , str ) and  not  (
579+             isinstance (color_map , dict )
580+             and  all (isinstance (k , str ) and  isinstance (v , str ) for  k , v  in  color_map .items ())
581+         ):
582+             raise  ValueError ("`color_map` must be a string or a dictionary of strings" )
583+         n_targets  =  len (target_properties ["target_names" ])
584+ 
585+         def  visualize_targets_one (images , idx = None ):
586+             # img: [T, 3, H, W] 
587+             out  =  {}
588+             for  target_name , img  in  zip (target_properties ["target_names" ], images ):
589+                 img  =  img .permute (1 , 2 , 0 )  # [H, W, 3] 
590+                 prediction_space  =  target_properties [target_name ].get ("prediction_space" , "srgb" )
591+                 if  prediction_space  ==  "stack" :
592+                     sub_target_names  =  target_properties [target_name ]["sub_target_names" ]
593+                     if  len (sub_target_names ) !=  3  or  any (
594+                         not  (isinstance (s , str ) or  s  is  None ) for  s  in  sub_target_names 
595+                     ):
596+                         raise  ValueError (f"Unexpected target sub-names { sub_target_names } { target_name }  )
597+                     for  i , sub_target_name  in  enumerate (sub_target_names ):
598+                         if  sub_target_name  is  None :
599+                             continue 
600+                         sub_img  =  img [:, :, i ]
601+                         sub_prediction_space  =  target_properties [sub_target_name ].get ("prediction_space" , "srgb" )
602+                         if  sub_prediction_space  ==  "linear" :
603+                             sub_up_to_scale  =  target_properties [sub_target_name ].get ("up_to_scale" , False )
604+                             if  sub_up_to_scale :
605+                                 sub_img  =  sub_img  /  max (sub_img .max ().item (), 1e-6 )
606+                             sub_img  =  sub_img  **  (1  /  2.2 )
607+                         cmap_name  =  (
608+                             color_map  if  isinstance (color_map , str ) else  color_map .get (sub_target_name , "binary" )
609+                         )
610+                         sub_img  =  MarigoldImageProcessor .colormap (sub_img , cmap = cmap_name , bytes = True )
611+                         sub_img  =  PIL .Image .fromarray (sub_img .cpu ().numpy ())
612+                         out [sub_target_name ] =  sub_img 
613+                 elif  prediction_space  ==  "linear" :
614+                     up_to_scale  =  target_properties [target_name ].get ("up_to_scale" , False )
615+                     if  up_to_scale :
616+                         img  =  img  /  max (img .max ().item (), 1e-6 )
617+                     img  =  img  **  (1  /  2.2 )
618+                 elif  prediction_space  ==  "srgb" :
619+                     pass 
620+                 img  =  (img  *  255 ).to (dtype = torch .uint8 , device = "cpu" ).numpy ()
621+                 img  =  PIL .Image .fromarray (img )
622+                 out [target_name ] =  img 
623+             return  out 
624+ 
625+         if  prediction  is  None  or  isinstance (prediction , list ) and  any (o  is  None  for  o  in  prediction ):
626+             raise  ValueError ("Input prediction is `None`" )
627+         if  isinstance (prediction , (np .ndarray , torch .Tensor )):
628+             prediction  =  MarigoldImageProcessor .expand_tensor_or_array (prediction )
629+             if  isinstance (prediction , np .ndarray ):
630+                 prediction  =  MarigoldImageProcessor .numpy_to_pt (prediction )  # [N*T,3,H,W] 
631+             if  not  (prediction .ndim  ==  4  and  prediction .shape [1 ] ==  3  and  prediction .shape [0 ] %  n_targets  ==  0 ):
632+                 raise  ValueError (f"Unexpected input shape={ prediction .shape }  )
633+             N_T , _ , H , W  =  prediction .shape 
634+             N  =  N_T  //  n_targets 
635+             prediction  =  prediction .reshape (N , n_targets , 3 , H , W )
636+             return  [visualize_targets_one (img , idx ) for  idx , img  in  enumerate (prediction )]
637+         elif  isinstance (prediction , list ):
638+             return  [visualize_targets_one (img , idx ) for  idx , img  in  enumerate (prediction )]
639+         else :
640+             raise  ValueError (f"Unexpected input type: { type (prediction )}  )
641+ 
531642    @staticmethod  
532643    def  visualize_uncertainty (
533644        uncertainty : Union [
@@ -537,24 +648,26 @@ def visualize_uncertainty(
537648            List [torch .Tensor ],
538649        ],
539650        saturation_percentile = 95 ,
540-     ) ->  Union [ PIL . Image . Image ,  List [PIL .Image .Image ] ]:
651+     ) ->  List [PIL .Image .Image ]:
541652        """ 
542-         Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`. 
653+         Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or 
654+         `MarigoldIntrinsicsPipeline`. 
543655
544656        Args: 
545657            uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): 
546658                Uncertainty maps. 
547659            saturation_percentile (`int`, *optional*, defaults to `95`): 
548660                Specifies the percentile uncertainty value visualized with maximum intensity. 
549661
550-         Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with uncertainty visualization. 
662+         Returns: `List[PIL.Image.Image]` with uncertainty visualization. 
551663        """ 
552664
553665        def  visualize_uncertainty_one (img , idx = None ):
554666            prefix  =  "Uncertainty"  +  (f"[{ idx }   if  idx  else  "" )
555667            if  img .min () <  0 :
556-                 raise  ValueError (f"{ prefix } { img .min ()}  )
557-             img  =  img .squeeze (0 ).cpu ().numpy ()
668+                 raise  ValueError (f"{ prefix } { img .min ()}  )
669+             img  =  img .permute (1 , 2 , 0 )  # [H,W,C] 
670+             img  =  img .squeeze (2 ).cpu ().numpy ()  # [H,W] or [H,W,3] 
558671            saturation_value  =  np .percentile (img , saturation_percentile )
559672            img  =  np .clip (img  *  255  /  saturation_value , 0 , 255 )
560673            img  =  img .astype (np .uint8 )
@@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
566679        if  isinstance (uncertainty , (np .ndarray , torch .Tensor )):
567680            uncertainty  =  MarigoldImageProcessor .expand_tensor_or_array (uncertainty )
568681            if  isinstance (uncertainty , np .ndarray ):
569-                 uncertainty  =  MarigoldImageProcessor .numpy_to_pt (uncertainty )  # [N,1 ,H,W] 
570-             if  not  (uncertainty .ndim  ==  4  and  uncertainty .shape [1 ] ==   1 ):
571-                 raise  ValueError (f"Unexpected input shape={ uncertainty .shape } 1 ,H,W]." )
682+                 uncertainty  =  MarigoldImageProcessor .numpy_to_pt (uncertainty )  # [N,C ,H,W] 
683+             if  not  (uncertainty .ndim  ==  4  and  uncertainty .shape [1 ] in  ( 1 ,  3 ) ):
684+                 raise  ValueError (f"Unexpected input shape={ uncertainty .shape } C ,H,W] with C in (1,3) ." )
572685            return  [visualize_uncertainty_one (img , idx ) for  idx , img  in  enumerate (uncertainty )]
573686        elif  isinstance (uncertainty , list ):
574687            return  [visualize_uncertainty_one (img , idx ) for  idx , img  in  enumerate (uncertainty )]
0 commit comments