@@ -134,6 +134,25 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
134134 torch .testing .assert_close (tf_actual (img_tensor ), tf_expected (img_tensor ))
135135
136136
137+ @pytest .mark .parametrize ("degrees, translate" , [((- 5.0 , 5.0 ), (0.05 , 0.05 )), ((10.0 , 10.0 ), (0.1 , 0.1 ))])
138+ def test_get_image_transforms_affine (img_tensor_factory , degrees , translate ):
139+ img_tensor = img_tensor_factory ()
140+ tf_cfg = ImageTransformsConfig (
141+ enable = True ,
142+ tfs = {
143+ "affine" : ImageTransformConfig (
144+ type = "RandomAffine" , kwargs = {"degrees" : degrees , "translate" : translate }
145+ )
146+ },
147+ )
148+ tf = ImageTransforms (tf_cfg )
149+ output = tf (img_tensor )
150+ # Verify output shape is preserved
151+ assert output .shape == img_tensor .shape
152+ # Verify transform is type RandomAffine
153+ assert isinstance (tf .transforms ["affine" ], v2 .RandomAffine )
154+
155+
137156def test_get_image_transforms_max_num_transforms (img_tensor_factory ):
138157 img_tensor = img_tensor_factory ()
139158 tf_cfg = ImageTransformsConfig (
@@ -262,7 +281,37 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
262281 # NOTE: PyTorch versions have different randomness, it might break this test.
263282 # See this PR: https://github.com/huggingface/lerobot/pull/1127.
264283
265- cfg = ImageTransformsConfig (enable = True )
284+ # Use config without affine to match original test artifacts
285+ cfg = ImageTransformsConfig (
286+ enable = True ,
287+ tfs = {
288+ "brightness" : ImageTransformConfig (
289+ weight = 1.0 ,
290+ type = "ColorJitter" ,
291+ kwargs = {"brightness" : (0.8 , 1.2 )},
292+ ),
293+ "contrast" : ImageTransformConfig (
294+ weight = 1.0 ,
295+ type = "ColorJitter" ,
296+ kwargs = {"contrast" : (0.8 , 1.2 )},
297+ ),
298+ "saturation" : ImageTransformConfig (
299+ weight = 1.0 ,
300+ type = "ColorJitter" ,
301+ kwargs = {"saturation" : (0.5 , 1.5 )},
302+ ),
303+ "hue" : ImageTransformConfig (
304+ weight = 1.0 ,
305+ type = "ColorJitter" ,
306+ kwargs = {"hue" : (- 0.05 , 0.05 )},
307+ ),
308+ "sharpness" : ImageTransformConfig (
309+ weight = 1.0 ,
310+ type = "SharpnessJitter" ,
311+ kwargs = {"sharpness" : (0.5 , 1.5 )},
312+ ),
313+ },
314+ )
266315 default_tf = ImageTransforms (cfg )
267316
268317 with seeded_context (1337 ):
@@ -368,7 +417,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
368417 save_each_transform (tf_cfg , img_tensor , tmp_path , n_examples )
369418
370419 # Check if the transformed images exist for each transform type
371- transforms = ["brightness" , "contrast" , "saturation" , "hue" , "sharpness" ]
420+ transforms = ["brightness" , "contrast" , "saturation" , "hue" , "sharpness" , "affine" ]
372421 for transform in transforms :
373422 transform_dir = tmp_path / transform
374423 assert transform_dir .exists (), f"{ transform } directory was not created."
0 commit comments