@@ -127,14 +127,16 @@ def rgb_to_lab_tensor(
127127 rgb_img : torch .Tensor ,
128128 normalized : bool = True ,
129129 srgb_input : bool = True ,
130- ) -> torch .Tensor :
130+ split_channels : bool = False ,
131+ ) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]:
131132 """
132133 Convert RGB image to LAB color space using tensor operations.
133134
134135 Args:
135136 rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
136137 normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
137-
138+ srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
139+ split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image
138140 Returns:
139141 lab_img: Tensor of same shape with either:
140142 - normalized=False: L in [0, 100] and a,b in [-128, 127]
@@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
152154 rgb_img = srgb_to_linear (rgb_img )
153155
154156 # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
155- rgb_img . clamp_ (0 , 1.0 )
157+ rgb_img = rgb_img . clamp (0 , 1.0 )
156158
157159 # Convert to XYZ using matrix multiplication
158160 rgb_to_xyz = torch .tensor ([
159- [0.412453 , 0.357580 , 0.180423 ],
160- [0.212671 , 0.715160 , 0.072169 ],
161- [0.019334 , 0.119193 , 0.950227 ]
161+ # X Y Z
162+ [0.412453 , 0.212671 , 0.019334 ], # R
163+ [0.357580 , 0.715160 , 0.119193 ], # G
164+ [0.180423 , 0.072169 , 0.950227 ], # B
162165 ], device = rgb_img .device )
163166
164167 # Reshape input for matrix multiplication if needed
@@ -167,38 +170,30 @@ def rgb_to_lab_tensor(
167170 rgb_img = rgb_img .reshape (- 1 , 3 )
168171
169172 # Perform matrix multiplication
170- xyz = torch . matmul ( rgb_img , rgb_to_xyz . T )
173+ xyz = rgb_img @ rgb_to_xyz
171174
172175 # Adjust XYZ values
173- xyz [..., 0 ].div_ (xn )
174- xyz [..., 1 ].div_ (yn )
175- xyz [..., 2 ].div_ (zn )
176+ xyz .div_ (torch .tensor ([xn , yn , zn ], device = xyz .device ))
176177
177178 # Step 4: XYZ to LAB
178- lab = torch .where (
179+ fxfyfz = torch .where (
179180 xyz > epsilon ,
180181 torch .pow (xyz , 1 / 3 ),
181182 (kappa * xyz + 16 ) / 116
182183 )
183184
185+ L = 116 * fxfyfz [..., 1 ] - 16
186+ a = 500 * (fxfyfz [..., 0 ] - fxfyfz [..., 1 ])
187+ b = 200 * (fxfyfz [..., 1 ] - fxfyfz [..., 2 ])
184188 if normalized :
185- # Calculate normalized [0,1] L,a,b values directly
186- # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
187- # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
188- # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
189- shift_128 = 128 / 255
190- a_scale = 500 / 255
191- b_scale = 200 / 255
192- L = 1.16 * lab [..., 1 ] - 0.16
193- a = a_scale * (lab [..., 0 ] - lab [..., 1 ]) + shift_128
194- b = b_scale * (lab [..., 1 ] - lab [..., 2 ]) + shift_128
195- else :
196- # Calculate native range L,a,b values
197- L = 116 * lab [..., 1 ] - 16
198- a = 500 * (lab [..., 0 ] - lab [..., 1 ])
199- b = 200 * (lab [..., 1 ] - lab [..., 2 ])
189+ # output in rage [0, 1] for each channel
190+ L .div_ (100 )
191+ a .add_ (128 ).div_ (255 )
192+ b .add_ (128 ).div_ (255 )
193+
194+ if split_channels :
195+ return L , a , b
200196
201- # Stack the results
202197 lab = torch .stack ([L , a , b ], dim = - 1 )
203198
204199 # Restore original shape if needed
0 commit comments