This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 76fef70

Fix reproducibility (#2561)
* Fix in random seed
* Fix
* Fix
* Fix
* Fix
* Fix in test
* Fix in tests
* Fix in tests
1 parent: 195edf7

File tree: 5 files changed, +456 -12 lines

.cursor/rules/albumentations-rules.mdc (1 addition, 0 deletions)

@@ -11,3 +11,4 @@ alwaysApply: true
 - we do not use fill_value, but fill. Not fill_mask_value, but fill_mask
 - We do not have ANY default values in the InitSchema class
 - Use pytest.mark.parametrize for parameterized tests
+- In the code, when need default value use 137, not 42

albumentations/core/composition.py (130 additions, 11 deletions)
@@ -193,9 +193,13 @@ def set_random_seed(self, seed: int | None) -> None:
             seed (int | None): Random seed to use

         """
+        # Store the original seed
         self.seed = seed
+
+        # Use base seed directly (subclasses like Compose can override this)
         self.random_generator = np.random.default_rng(seed)
         self.py_random = random.Random(seed)
+
         # Propagate seed to all transforms
         for transform in self.transforms:
             if isinstance(transform, (BasicTransform, BaseCompose)):
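
Note: this first hunk only adds comments and spacing, but BaseCompose.set_random_seed is the hook the rest of the commit builds on: seeding a pipeline seeds both generators and every nested transform. A minimal sketch of the resulting main-process behavior (the specific transforms here are illustrative, not part of the commit):

import albumentations as A
import numpy as np

image = np.zeros((100, 100, 3), dtype=np.uint8)
image[25:75, 25:75] = 255

# Two pipelines constructed with the same seed should replay the same
# sequence of random decisions in the main process.
t1 = A.Compose([A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.5)], seed=137)
t2 = A.Compose([A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.5)], seed=137)

outs1 = [t1(image=image)["image"] for _ in range(5)]
outs2 = [t2(image=image)["image"] for _ in range(5)]
assert all(np.array_equal(a, b) for a, b in zip(outs1, outs2))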
@@ -572,6 +576,35 @@ def _get_init_params(self) -> dict[str, Any]:
             "p": self.p,
         }

+    def _get_effective_seed(self, base_seed: int | None) -> int | None:
+        """Get effective seed considering worker context.
+
+        Args:
+            base_seed (int | None): Base seed value
+
+        Returns:
+            int | None: Effective seed after considering worker context
+
+        """
+        if base_seed is None:
+            return base_seed
+
+        try:
+            import torch
+            import torch.utils.data
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                # We're in a DataLoader worker process
+                # Use torch.initial_seed() which is unique per worker and changes on respawn
+                torch_seed = torch.initial_seed() % (2**32)
+                return (base_seed + torch_seed) % (2**32)
+        except (ImportError, AttributeError):
+            # PyTorch not available or not in worker context
+            pass
+
+        return base_seed
+

 class Compose(BaseCompose, HubMixin):
     """Compose multiple transforms together and apply them sequentially to input data.
@@ -676,11 +709,17 @@ def __init__(
         seed: int | None = None,
         save_applied_params: bool = False,
     ):
+        # Store the original base seed for worker context recalculation
+        self._base_seed = seed
+
+        # Get effective seed considering worker context
+        effective_seed = self._get_effective_seed(seed)
+
         super().__init__(
             transforms=transforms,
             p=p,
             mask_interpolation=mask_interpolation,
-            seed=seed,
+            seed=effective_seed,
             save_applied_params=save_applied_params,
         )

@@ -725,6 +764,7 @@ def __init__(
         self.save_applied_params = save_applied_params
         self._images_was_list = False
         self._masks_was_list = False
+        self._last_torch_seed: int | None = None

     @property
     def strict(self) -> bool:
@@ -788,7 +828,7 @@ def disable_check_args_private(self) -> None:
         self.main_compose = False

     def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
-        """Apply transformations to data.
+        """Apply transformations to data with automatic worker seed synchronization.

         Args:
             *args (Any): Positional arguments are not supported.
@@ -802,14 +842,13 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
             KeyError: If positional arguments are provided.

         """
+        # Check and sync worker seed if needed
+        self._check_worker_seed()
+
         if args:
             msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
             raise KeyError(msg)

-        if not isinstance(force_apply, (bool, int)):
-            msg = "force_apply must have bool or int type"
-            raise TypeError(msg)
-
         # Initialize applied_transforms only in top-level Compose if requested
         if self.save_applied_params and self.main_compose:
             data["applied_transforms"] = []
@@ -827,6 +866,84 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:

         return self.postprocess(data)

+    def _check_worker_seed(self) -> None:
+        """Check and update random seed if in worker context."""
+        if not hasattr(self, "_base_seed") or self._base_seed is None:
+            return
+
+        # Check if we're in a worker and need to update the seed
+        try:
+            import torch
+            import torch.utils.data
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                # Get the current torch initial seed
+                current_torch_seed = torch.initial_seed()
+
+                # Check if we've already synchronized for this seed
+                if hasattr(self, "_last_torch_seed") and self._last_torch_seed == current_torch_seed:
+                    return
+
+                # Update the seed and mark as synchronized
+                self._last_torch_seed = current_torch_seed
+                effective_seed = self._get_effective_seed(self._base_seed)
+
+                # Update our own random state
+                self.random_generator = np.random.default_rng(effective_seed)
+                self.py_random = random.Random(effective_seed)
+
+                # Propagate to all transforms
+                for transform in self.transforms:
+                    if hasattr(transform, "set_random_state"):
+                        transform.set_random_state(self.random_generator, self.py_random)
+                    elif hasattr(transform, "set_random_seed"):
+                        # For transforms that don't have set_random_state, use set_random_seed
+                        transform.set_random_seed(effective_seed)
+        except (ImportError, AttributeError):
+            pass
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Set state from unpickling and handle worker seed."""
+        self.__dict__.update(state)
+        # If we have a base seed, recalculate effective seed in worker context
+        if hasattr(self, "_base_seed") and self._base_seed is not None:
+            # Reset _last_torch_seed to ensure worker-seed sync runs after unpickling
+            self._last_torch_seed = None
+            # Recalculate effective seed in worker context
+            self.set_random_seed(self._base_seed)
+        elif hasattr(self, "seed") and self.seed is not None:
+            # For backward compatibility, if no base seed but seed exists
+            self._base_seed = self.seed
+            self._last_torch_seed = None
+            self.set_random_seed(self.seed)
+
+    def set_random_seed(self, seed: int | None) -> None:
+        """Override to use worker-aware seed functionality.
+
+        Args:
+            seed (int | None): Random seed to use
+
+        """
+        # Store the original base seed
+        self._base_seed = seed
+        self.seed = seed
+
+        # Get effective seed considering worker context
+        effective_seed = self._get_effective_seed(seed)
+
+        # Initialize random generators with effective seed
+        self.random_generator = np.random.default_rng(effective_seed)
+        self.py_random = random.Random(effective_seed)
+
+        # Propagate to all transforms
+        for transform in self.transforms:
+            if hasattr(transform, "set_random_state"):
+                transform.set_random_state(self.random_generator, self.py_random)
+            elif hasattr(transform, "set_random_seed"):
+                # For transforms that don't have set_random_state, use set_random_seed
+                transform.set_random_seed(effective_seed)
+
     def preprocess(self, data: Any) -> None:
         """Preprocess input data before applying transforms."""
         # Always validate shapes if is_check_shapes is True, regardless of strict mode
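
Together, __setstate__ (which resets _last_torch_seed when the pipeline is pickled into a worker process) and the _check_worker_seed call at the top of __call__ make per-worker reseeding automatic. A minimal end-to-end sketch, assuming PyTorch is installed; the dataset, shapes, and transform are illustrative:

import albumentations as A
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self) -> None:
        # The base seed (137) is stored as _base_seed; each worker derives
        # its own effective seed from it on the first call.
        self.transform = A.Compose([A.HorizontalFlip(p=0.5)], seed=137)
        self.images = [np.full((8, 8, 3), i, dtype=np.uint8) for i in range(16)]

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> np.ndarray:
        return self.transform(image=self.images[idx])["image"]

if __name__ == "__main__":
    # Worker initial seeds derive from torch's RNG, so fixing it makes the
    # per-worker augmentation streams reproducible across runs.
    torch.manual_seed(0)
    loader = DataLoader(ToyDataset(), batch_size=4, num_workers=2)
    batches = list(loader)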
@@ -959,6 +1076,7 @@ def to_dict_private(self) -> dict[str, Any]:
                 "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None),
                 "additional_targets": self.additional_targets,
                 "is_check_shapes": self.is_check_shapes,
+                "seed": getattr(self, "_base_seed", None),
             },
         )
         return dictionary
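
Serializing the base seed rather than a worker-adjusted effective seed keeps saved pipelines portable. A sketch of the round trip via the library's to_dict/from_dict helpers (assuming from_dict forwards the stored arguments back to Compose):

import albumentations as A

transform = A.Compose([A.HorizontalFlip(p=0.5)], seed=137)

# The dictionary records the user-supplied base seed (137), not an
# effective seed derived inside some worker process.
serialized = A.to_dict(transform)
restored = A.from_dict(serialized)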
@@ -1201,7 +1319,7 @@ def _get_init_params(self) -> dict[str, Any]:
             "is_check_shapes": self.is_check_shapes,
             "strict": self.strict,
             "mask_interpolation": getattr(self, "mask_interpolation", None),
-            "seed": getattr(self, "seed", None),
+            "seed": getattr(self, "_base_seed", None),
             "save_applied_params": getattr(self, "save_applied_params", False),
         }

@@ -1445,7 +1563,7 @@ def __init__(
             msg = "You must set both first and second or set transforms argument."
             raise ValueError(msg)
         transforms = [first, second]
-        super().__init__(transforms, p)
+        super().__init__(transforms=transforms, p=p)
         if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
             warnings.warn("Length of transforms is not equal to 2.", stacklevel=2)

@@ -1503,7 +1621,7 @@ def __init__(
         channels: Sequence[int] = (0, 1, 2),
         p: float = 1.0,
     ) -> None:
-        super().__init__(transforms, p)
+        super().__init__(transforms=transforms, p=p)
         self.channels = channels

     def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
@@ -1525,8 +1643,9 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[s
15251643
sub_image = np.ascontiguousarray(selected_channels)
15261644

15271645
for t in self.transforms:
1528-
sub_image = t(image=sub_image)["image"]
1529-
self._track_transform_params(t, sub_image)
1646+
sub_data = {"image": sub_image}
1647+
sub_image = t(**sub_data)["image"]
1648+
self._track_transform_params(t, sub_data)
15301649

15311650
transformed_channels = cv2.split(sub_image)
15321651
output_img = image.copy()
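
This last hunk changes what per-channel application records: _track_transform_params now receives the input data dict rather than the already-transformed image array. A sketch of observing the tracked parameters through save_applied_params (the specific transforms are illustrative):

import albumentations as A
import numpy as np

image = np.zeros((16, 16, 3), dtype=np.uint8)
transform = A.Compose(
    [A.PerChannel([A.Blur(p=1.0)], p=1.0)],
    save_applied_params=True,
    seed=137,
)
result = transform(image=image)
# Parameters tracked per applied transform, populated by
# _track_transform_params during the call.
print(result["applied_transforms"])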
