1010import  torch 
1111from  PIL  import  __version__  as  PILLOW_VERSION_STRING , Image , ImageColor , ImageDraw , ImageFont 
1212
13- 
1413__all__  =  [
1514    "_Image_fromarray" ,
1615    "make_grid" ,
@@ -293,6 +292,7 @@ def draw_bounding_boxes(
293292    font : Optional [str ] =  None ,
294293    font_size : Optional [int ] =  None ,
295294    label_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] =  None ,
295+     label_background_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] =  None ,
296296    fill_labels : bool  =  False ,
297297) ->  torch .Tensor :
298298    """ 
@@ -320,7 +320,10 @@ def draw_bounding_boxes(
320320        font_size (int): The requested font size in points. 
321321        label_colors (color or list of colors, optional): Colors for the label text.  See the description of the 
322322            `colors` argument for details.  Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True. 
323-         fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False. 
323+         label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the 
324+             same colors used for the boxes. Ignored when ``fill_labels`` is False. 
325+         fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter, 
326+             or from the ``colors`` parameter if not specified). Default: False. 
324327
325328    Returns: 
326329        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. 
@@ -362,6 +365,11 @@ def draw_bounding_boxes(
362365    else :
363366        label_colors  =  colors .copy ()  # type: ignore[assignment] 
364367
368+     if  fill_labels  and  label_background_colors :
369+         label_background_colors  =  _parse_colors (label_background_colors , num_objects = num_boxes )
370+     else :
371+         label_background_colors  =  colors .copy ()  # type: ignore[assignment] 
372+ 
365373    if  font  is  None :
366374        if  font_size  is  not None :
367375            warnings .warn ("Argument 'font_size' will be ignored since 'font' is not set." )
@@ -385,7 +393,7 @@ def draw_bounding_boxes(
385393    else :
386394        draw  =  _ImageDrawTV (img_to_draw )
387395
388-     for  bbox , color , label , label_color   in  zip (img_boxes , colors , labels , label_colors ):  # type: ignore[arg-type] 
396+     for  bbox , color , label , label_color ,  label_bg_color   in  zip (img_boxes , colors , labels , label_colors ,  label_background_colors ):  # type: ignore[arg-type] 
389397        draw_method  =  draw .oriented_rectangle  if  len (bbox ) >  4  else  draw .rectangle 
390398        fill_color  =  color  +  (100 ,) if  fill  else  None 
391399        draw_method (bbox , width = width , outline = color , fill = fill_color )
@@ -396,7 +404,7 @@ def draw_bounding_boxes(
396404            if  fill_labels :
397405                left , top , right , bottom  =  draw .textbbox ((bbox [0 ] +  margin , bbox [1 ] +  margin ), label , font = txt_font )
398406                draw .rectangle (
399-                     (left  -  box_margin , top  -  box_margin , right  +  box_margin , bottom  +  box_margin ), fill = color 
407+                     (left  -  box_margin , top  -  box_margin , right  +  box_margin , bottom  +  box_margin ), fill = label_bg_color 
400408                )
401409            draw .text ((bbox [0 ] +  margin , bbox [1 ] +  margin ), label , fill = label_color , font = txt_font )  # type: ignore[arg-type] 
402410
@@ -545,7 +553,7 @@ def draw_keypoints(
545553    if  visibility .shape  !=  keypoints .shape [:- 1 ]:
546554        raise  ValueError (
547555            "keypoints and visibility must have the same dimensionality for num_instances and K. " 
548-             f"Got { visibility .shape   =   } { keypoints .shape   =   }  
556+             f"Got { visibility .shape = } { keypoints .shape = }  
549557        )
550558
551559    original_dtype  =  image .dtype 
@@ -746,7 +754,7 @@ def _parse_colors(
746754                f"Number of colors must be equal or larger than the number of objects, but got { len (colors )} { num_objects }  
747755            )
748756    elif  not  isinstance (colors , (tuple , str )):
749-         raise  ValueError (f"` colors`  must be a tuple or a string, or a list thereof, but got { colors }  )
757+         raise  ValueError (f"colors must be a tuple or a string, or a list thereof, but got { colors }  )
750758    elif  isinstance (colors , tuple ) and  len (colors ) !=  3 :
751759        raise  ValueError (f"If passed as tuple, colors should be an RGB triplet, but got { colors }  )
752760    else :  # colors specifies a single color for all objects 
0 commit comments