@@ -633,24 +633,20 @@ def __call__(self, X, alpha=None, bytes=False):
633633 xa [xa < 0 ] = self ._i_under
634634 xa [mask_bad ] = self ._i_bad
635635
636+ lut = self ._lut
636637 if bytes :
637- lut = (self ._lut * 255 ).astype (np .uint8 )
638- else :
639- lut = self ._lut .copy () # Don't let alpha modify original _lut.
638+ lut = (lut * 255 ).astype (np .uint8 )
640639
641- rgba = np .empty (shape = xa .shape + (4 ,), dtype = lut .dtype )
642- lut .take (xa , axis = 0 , mode = 'clip' , out = rgba )
640+ rgba = lut .take (xa , axis = 0 , mode = 'clip' )
643641
644642 if alpha is not None :
645- if np .iterable (alpha ):
646- alpha = np .asarray (alpha )
647- if alpha .shape != xa .shape :
648- raise ValueError ("alpha is array-like but its shape"
649- " %s doesn't match that of X %s" %
650- (alpha .shape , xa .shape ))
651643 alpha = np .clip (alpha , 0 , 1 )
652644 if bytes :
653- alpha = (alpha * 255 ).astype (np .uint8 )
645+ alpha *= 255 # Will be cast to uint8 upon assignment.
646+ if alpha .shape not in [(), xa .shape ]:
647+ raise ValueError (
648+ f"alpha is array-like but its shape { alpha .shape } does "
649+ f"not match that of X { xa .shape } " )
654650 rgba [..., - 1 ] = alpha
655651
656652 # If the "bad" color is all zeros, then ignore alpha input.
0 commit comments