@@ -90,7 +90,7 @@ def __repr__(self) -> str:
9090 return f"{ self .__class__ .__name__ } ()"
9191
9292
93- class ToLab ( transforms . ToTensor ) :
93+ class ToLabPIL :
9494
9595 def __init__ (self ) -> None :
9696 super ().__init__ ()
@@ -115,6 +115,121 @@ def __repr__(self) -> str:
115115 return f"{ self .__class__ .__name__ } ()"
116116
117117
def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
    """Undo the sRGB transfer curve, element-wise.

    Values at or below 0.04045 use the linear segment (``x / 12.92``); larger
    values use the power segment (``((x + 0.055) / 1.055) ** 2.4``).

    Args:
        srgb_image: Tensor of gamma-encoded sRGB values, nominally in [0, 1].
            Any shape; the conversion is applied per element.

    Returns:
        A new tensor of the same shape holding linear-light values.
    """
    # torch.where evaluates both branches for every element. A negative base
    # raised to a fractional power is NaN, and although the forward pass masks
    # it out, the NaN still leaks into the backward pass (0 * NaN). Clamping the
    # base at zero leaves every selected value untouched (those are all > 0)
    # and keeps the unselected branch finite.
    power_base = ((srgb_image + 0.055) / 1.055).clamp(min=0.0)
    return torch.where(
        srgb_image <= 0.04045,
        srgb_image / 12.92,
        power_base ** 2.4,
    )
124+
125+
def rgb_to_lab_tensor(
        rgb_img: torch.Tensor,
        normalized: bool = True,
        srgb_input: bool = True,
) -> torch.Tensor:
    """Convert an RGB image to LAB color space using tensor operations.

    Args:
        rgb_img: Tensor of shape (..., 3) (channels last) with values nominally
            in [0, 1]; out-of-range values are clamped to [0, 1] after the
            optional sRGB linearization. The input tensor is not modified.
        normalized: If True, outputs L,a,b rescaled to roughly [0, 1] instead of
            native LAB ranges.
        srgb_input: If True, the input is treated as gamma-encoded sRGB and is
            linearized with ``srgb_to_linear`` first. If False, the input is
            used as linear RGB directly.

    Returns:
        lab_img: Tensor of the same shape as the input with either:
            - normalized=False: L in [0, 100] and a,b nominally in [-128, 127]
            - normalized=True: L,a,b nominally in [0, 1]
    """
    # Constants of the piecewise XYZ -> LAB companding function
    epsilon = 216 / 24389
    kappa = 24389 / 27
    # Reference white used to normalize X, Y, Z (D65 values)
    xn = 0.95047
    yn = 1.0
    zn = 1.08883

    # Convert sRGB to linear RGB
    if srgb_input:
        rgb_img = srgb_to_linear(rgb_img)

    # Integer tensors cannot be multiplied with the float conversion matrix.
    if not rgb_img.is_floating_point():
        rgb_img = rgb_img.to(torch.get_default_dtype())

    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
    # Out-of-place clamp so the caller's tensor is never mutated (an in-place
    # clamp also fails on leaf tensors that require grad).
    rgb_img = rgb_img.clamp(0, 1.0)

    # Linear RGB -> XYZ matrix, created in the input's dtype/device so that
    # float64 / half inputs do not hit a dtype mismatch in matmul.
    rgb_to_xyz = torch.tensor([
        [0.412453, 0.357580, 0.180423],
        [0.212671, 0.715160, 0.072169],
        [0.019334, 0.119193, 0.950227],
    ], dtype=rgb_img.dtype, device=rgb_img.device)
    white_point = torch.tensor([xn, yn, zn], dtype=rgb_img.dtype, device=rgb_img.device)

    # Flatten leading dims for the matrix multiplication if needed.
    # NOTE(review): for inputs with more than 2 dims this assumes the last dim
    # is the channel dim; a channels-first (3, H, W) tensor would be reshaped
    # without error and produce meaningless output -- confirm callers pass
    # channels-last data.
    original_shape = rgb_img.shape
    if len(original_shape) > 2:
        rgb_img = rgb_img.reshape(-1, 3)

    # Project to XYZ and normalize each component by the reference white
    xyz = torch.matmul(rgb_img, rgb_to_xyz.T) / white_point

    # XYZ -> LAB companding: cube root above epsilon, linear segment below.
    # The cube-root branch gets a clamped base so its (unselected) gradient at
    # zero stays finite; selected values (> epsilon) are unaffected.
    f_xyz = torch.where(
        xyz > epsilon,
        torch.pow(xyz.clamp(min=epsilon), 1 / 3),
        (kappa * xyz + 16) / 116,
    )
    fx, fy, fz = f_xyz[..., 0], f_xyz[..., 1], f_xyz[..., 2]

    if normalized:
        # Calculate normalized [0,1] L,a,b values directly
        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
        shift_128 = 128 / 255
        a_scale = 500 / 255
        b_scale = 200 / 255
        L = 1.16 * fy - 0.16
        a = a_scale * (fx - fy) + shift_128
        b = b_scale * (fy - fz) + shift_128
    else:
        # Calculate native range L,a,b values
        L = 116 * fy - 16
        a = 500 * (fx - fy)
        b = 200 * (fy - fz)

    # Stack the results along a new trailing channel dim
    lab = torch.stack([L, a, b], dim=-1)

    # Restore original shape if needed
    if len(original_shape) > 2:
        lab = lab.reshape(original_shape)

    return lab
209+
210+
class ToLabTensor:
    """Callable transform that converts an RGB tensor to LAB via ``rgb_to_lab_tensor``.

    Args:
        srgb_input: Forwarded to ``rgb_to_lab_tensor``; if True the input is
            linearized from sRGB before conversion. Note the default here is
            False, whereas ``rgb_to_lab_tensor`` itself defaults to True.
        normalized: Forwarded to ``rgb_to_lab_tensor``; if True the L,a,b
            outputs are rescaled to roughly [0, 1].
    """

    def __init__(self, srgb_input: bool = False, normalized: bool = True) -> None:
        self.srgb_input = srgb_input
        self.normalized = normalized

    def __call__(self, pic: torch.Tensor) -> torch.Tensor:
        """Convert ``pic`` to LAB using the configured options."""
        return rgb_to_lab_tensor(
            pic,
            normalized=self.normalized,
            srgb_input=self.srgb_input,
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(srgb_input={self.srgb_input}, normalized={self.normalized})"
        )
222+
223+
class ToLinearRgb:
    """Callable transform that converts an sRGB tensor to linear RGB via ``srgb_to_linear``."""

    def __call__(self, pic: torch.Tensor) -> torch.Tensor:
        """Linearize ``pic``.

        Raises:
            TypeError: If ``pic`` is not a ``torch.Tensor``.
        """
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if not isinstance(pic, torch.Tensor):
            raise TypeError(f"pic should be a torch.Tensor. Got {type(pic)}")
        return srgb_to_linear(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
231+
232+
118233# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
119234# favor of the Image.Resampling enum. The top-level resampling attributes will be
120235# removed in Pillow 10.
0 commit comments