Skip to content

Commit 8810094

Browse files
add affine transforms and test (#2145)
Co-authored-by: Steven Palma <[email protected]>
1 parent a95b15c commit 8810094

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

src/lerobot/datasets/transforms.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
206206
type="SharpnessJitter",
207207
kwargs={"sharpness": (0.5, 1.5)},
208208
),
209+
"affine": ImageTransformConfig(
210+
weight=1.0,
211+
type="RandomAffine",
212+
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
213+
),
209214
}
210215
)
211216

@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
217222
return v2.ColorJitter(**cfg.kwargs)
218223
elif cfg.type == "SharpnessJitter":
219224
return SharpnessJitter(**cfg.kwargs)
225+
elif cfg.type == "RandomAffine":
226+
return v2.RandomAffine(**cfg.kwargs)
220227
else:
221228
raise ValueError(f"Transform '{cfg.type}' is not valid.")
222229

tests/datasets/test_image_transforms.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,25 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
134134
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
135135

136136

137+
@pytest.mark.parametrize("degrees, translate", [((-5.0, 5.0), (0.05, 0.05)), ((10.0, 10.0), (0.1, 0.1))])
138+
def test_get_image_transforms_affine(img_tensor_factory, degrees, translate):
139+
img_tensor = img_tensor_factory()
140+
tf_cfg = ImageTransformsConfig(
141+
enable=True,
142+
tfs={
143+
"affine": ImageTransformConfig(
144+
type="RandomAffine", kwargs={"degrees": degrees, "translate": translate}
145+
)
146+
},
147+
)
148+
tf = ImageTransforms(tf_cfg)
149+
output = tf(img_tensor)
150+
# Verify output shape is preserved
151+
assert output.shape == img_tensor.shape
152+
# Verify transform is type RandomAffine
153+
assert isinstance(tf.transforms["affine"], v2.RandomAffine)
154+
155+
137156
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
138157
img_tensor = img_tensor_factory()
139158
tf_cfg = ImageTransformsConfig(
@@ -262,7 +281,37 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
262281
# NOTE: PyTorch versions have different randomness, it might break this test.
263282
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
264283

265-
cfg = ImageTransformsConfig(enable=True)
284+
# Use config without affine to match original test artifacts
285+
cfg = ImageTransformsConfig(
286+
enable=True,
287+
tfs={
288+
"brightness": ImageTransformConfig(
289+
weight=1.0,
290+
type="ColorJitter",
291+
kwargs={"brightness": (0.8, 1.2)},
292+
),
293+
"contrast": ImageTransformConfig(
294+
weight=1.0,
295+
type="ColorJitter",
296+
kwargs={"contrast": (0.8, 1.2)},
297+
),
298+
"saturation": ImageTransformConfig(
299+
weight=1.0,
300+
type="ColorJitter",
301+
kwargs={"saturation": (0.5, 1.5)},
302+
),
303+
"hue": ImageTransformConfig(
304+
weight=1.0,
305+
type="ColorJitter",
306+
kwargs={"hue": (-0.05, 0.05)},
307+
),
308+
"sharpness": ImageTransformConfig(
309+
weight=1.0,
310+
type="SharpnessJitter",
311+
kwargs={"sharpness": (0.5, 1.5)},
312+
),
313+
},
314+
)
266315
default_tf = ImageTransforms(cfg)
267316

268317
with seeded_context(1337):
@@ -368,7 +417,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
368417
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
369418

370419
# Check if the transformed images exist for each transform type
371-
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
420+
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness", "affine"]
372421
for transform in transforms:
373422
transform_dir = tmp_path / transform
374423
assert transform_dir.exists(), f"{transform} directory was not created."

0 commit comments

Comments
 (0)