@@ -193,9 +193,13 @@ def set_random_seed(self, seed: int | None) -> None:
             seed (int | None): Random seed to use

         """
+        # Store the original seed
         self.seed = seed
+
+        # Use base seed directly (subclasses like Compose can override this)
         self.random_generator = np.random.default_rng(seed)
         self.py_random = random.Random(seed)
+
         # Propagate seed to all transforms
         for transform in self.transforms:
             if isinstance(transform, (BasicTransform, BaseCompose)):
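
The hunk above makes `BaseCompose.set_random_seed` store the seed before rebuilding both RNGs and re-propagating to children. A minimal sketch of the intended behavior, assuming a release where `Compose` accepts a `seed` argument:

```python
import numpy as np
import albumentations as A

image = np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3)
aug = A.Compose([A.HorizontalFlip(p=0.5)], seed=42)

first = aug(image=image)["image"]
aug.set_random_seed(42)  # reset both RNGs and re-propagate to children
second = aug(image=image)["image"]
assert np.array_equal(first, second)  # same seed, same flip decision
```
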
@@ -572,6 +576,35 @@ def _get_init_params(self) -> dict[str, Any]:
             "p": self.p,
         }

+    def _get_effective_seed(self, base_seed: int | None) -> int | None:
+        """Get effective seed considering worker context.
+
+        Args:
+            base_seed (int | None): Base seed value
+
+        Returns:
+            int | None: Effective seed after considering worker context
+
+        """
+        if base_seed is None:
+            return base_seed
+
+        try:
+            import torch
+            import torch.utils.data
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                # We're in a DataLoader worker process.
+                # Use torch.initial_seed(), which is unique per worker and changes on respawn.
+                torch_seed = torch.initial_seed() % (2**32)
+                return (base_seed + torch_seed) % (2**32)
+        except (ImportError, AttributeError):
+            # PyTorch not available or not in worker context
+            pass
+
+        return base_seed
+

 class Compose(BaseCompose, HubMixin):
     """Compose multiple transforms together and apply them sequentially to input data.
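
Outside a `DataLoader` worker, `get_worker_info()` returns `None` and the base seed passes through unchanged; inside a worker, each process mixes its own `torch.initial_seed()` into the base seed. A standalone sketch of the same arithmetic (the `effective_seed` helper name is illustrative):

```python
import torch
import torch.utils.data


def effective_seed(base_seed: int | None) -> int | None:
    """Replicates the seed mixing from _get_effective_seed above."""
    if base_seed is None:
        return None
    if torch.utils.data.get_worker_info() is None:
        return base_seed  # main process: base seed unchanged
    # worker process: offset by the per-worker torch seed, kept in uint32 range
    return (base_seed + torch.initial_seed() % (2**32)) % (2**32)


print(effective_seed(137))  # in the main process this prints 137
```
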
@@ -676,11 +709,17 @@ def __init__(
         seed: int | None = None,
         save_applied_params: bool = False,
     ):
+        # Store the original base seed for worker context recalculation
+        self._base_seed = seed
+
+        # Get effective seed considering worker context
+        effective_seed = self._get_effective_seed(seed)
+
         super().__init__(
             transforms=transforms,
             p=p,
             mask_interpolation=mask_interpolation,
-            seed=seed,
+            seed=effective_seed,
             save_applied_params=save_applied_params,
         )
@@ -725,6 +764,7 @@ def __init__(
         self.save_applied_params = save_applied_params
         self._images_was_list = False
         self._masks_was_list = False
+        self._last_torch_seed: int | None = None

     @property
     def strict(self) -> bool:
@@ -788,7 +828,7 @@ def disable_check_args_private(self) -> None:
         self.main_compose = False

     def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
-        """Apply transformations to data.
+        """Apply transformations to data with automatic worker seed synchronization.

         Args:
             *args (Any): Positional arguments are not supported.
@@ -802,14 +842,13 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
             KeyError: If positional arguments are provided.

         """
+        # Check and sync worker seed if needed
+        self._check_worker_seed()
+
         if args:
             msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
             raise KeyError(msg)

-        if not isinstance(force_apply, (bool, int)):
-            msg = "force_apply must have bool or int type"
-            raise TypeError(msg)
-
         # Initialize applied_transforms only in top-level Compose if requested
         if self.save_applied_params and self.main_compose:
             data["applied_transforms"] = []
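
The `_check_worker_seed()` hook added above is what makes a seeded pipeline worker-aware end to end. A hedged sketch of the scenario this change targets; the dataset class and its fields are illustrative, not part of the library:

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset

import albumentations as A


class ToyDataset(Dataset):
    """Illustrative dataset; the seeded pipeline is pickled into each worker."""

    def __init__(self) -> None:
        self.aug = A.Compose([A.HorizontalFlip(p=0.5)], seed=42)

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> np.ndarray:
        image = np.full((4, 4, 3), idx, dtype=np.uint8)
        return self.aug(image=image)["image"]


if __name__ == "__main__":
    # Each worker mixes its own torch.initial_seed() into the base seed on the
    # first __call__, so workers no longer emit identical augmentation streams.
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2)
    for batch in loader:
        print(batch.shape)  # torch.Size([2, 4, 4, 3])
```
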
@@ -827,6 +866,84 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:

         return self.postprocess(data)

+    def _check_worker_seed(self) -> None:
+        """Check and update the random seed if in a worker context."""
+        if not hasattr(self, "_base_seed") or self._base_seed is None:
+            return
+
+        # Check if we're in a worker and need to update the seed
+        try:
+            import torch
+            import torch.utils.data
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                # Get the current torch initial seed
+                current_torch_seed = torch.initial_seed()
+
+                # Check if we've already synchronized for this seed
+                if hasattr(self, "_last_torch_seed") and self._last_torch_seed == current_torch_seed:
+                    return
+
+                # Update the seed and mark as synchronized
+                self._last_torch_seed = current_torch_seed
+                effective_seed = self._get_effective_seed(self._base_seed)
+
+                # Update our own random state
+                self.random_generator = np.random.default_rng(effective_seed)
+                self.py_random = random.Random(effective_seed)
+
+                # Propagate to all transforms
+                for transform in self.transforms:
+                    if hasattr(transform, "set_random_state"):
+                        transform.set_random_state(self.random_generator, self.py_random)
+                    elif hasattr(transform, "set_random_seed"):
+                        # For transforms that don't have set_random_state, use set_random_seed
+                        transform.set_random_seed(effective_seed)
+        except (ImportError, AttributeError):
+            pass
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore state when unpickling and handle the worker seed."""
+        self.__dict__.update(state)
+        # If we have a base seed, recalculate the effective seed in worker context
+        if hasattr(self, "_base_seed") and self._base_seed is not None:
+            # Reset _last_torch_seed to ensure worker-seed sync runs after unpickling
+            self._last_torch_seed = None
+            # Recalculate effective seed in worker context
+            self.set_random_seed(self._base_seed)
+        elif hasattr(self, "seed") and self.seed is not None:
+            # For backward compatibility: no base seed stored, but a seed exists
+            self._base_seed = self.seed
+            self._last_torch_seed = None
+            self.set_random_seed(self.seed)
+
+    def set_random_seed(self, seed: int | None) -> None:
+        """Override to use worker-aware seed functionality.
+
+        Args:
+            seed (int | None): Random seed to use
+
+        """
+        # Store the original base seed
+        self._base_seed = seed
+        self.seed = seed
+
+        # Get effective seed considering worker context
+        effective_seed = self._get_effective_seed(seed)
+
+        # Initialize random generators with the effective seed
+        self.random_generator = np.random.default_rng(effective_seed)
+        self.py_random = random.Random(effective_seed)
+
+        # Propagate to all transforms
+        for transform in self.transforms:
+            if hasattr(transform, "set_random_state"):
+                transform.set_random_state(self.random_generator, self.py_random)
+            elif hasattr(transform, "set_random_seed"):
+                # For transforms that don't have set_random_state, use set_random_seed
+                transform.set_random_seed(effective_seed)
+
     def preprocess(self, data: Any) -> None:
         """Preprocess input data before applying transforms."""
         # Always validate shapes if is_check_shapes is True, regardless of strict mode
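
`DataLoader` delivers the pipeline to each worker by pickling, so `__setstate__` is where `_last_torch_seed` gets reset. A minimal sketch of the round trip, run in the main process where no worker context exists:

```python
import pickle

import numpy as np
import albumentations as A

image = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)

aug = A.Compose([A.HorizontalFlip(p=0.5)], seed=42)
clone = pickle.loads(pickle.dumps(aug))

# No worker context in the main process, so both resolve to the base seed 42
assert np.array_equal(aug(image=image)["image"], clone(image=image)["image"])
```
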
@@ -959,6 +1076,7 @@ def to_dict_private(self) -> dict[str, Any]:
                 "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None),
                 "additional_targets": self.additional_targets,
                 "is_check_shapes": self.is_check_shapes,
+                "seed": getattr(self, "_base_seed", None),
             },
         )
         return dictionary
@@ -1201,7 +1319,7 @@ def _get_init_params(self) -> dict[str, Any]:
             "is_check_shapes": self.is_check_shapes,
             "strict": self.strict,
             "mask_interpolation": getattr(self, "mask_interpolation", None),
-            "seed": getattr(self, "seed", None),
+            "seed": getattr(self, "_base_seed", None),
             "save_applied_params": getattr(self, "save_applied_params", False),
         }

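
Both serialization paths now report `_base_seed`, so a saved pipeline records the user-supplied seed rather than a worker-mixed value. A hedged sketch, assuming the library's standard `A.to_dict` helper and its usual `{"__version__": ..., "transform": {...}}` layout:

```python
import albumentations as A

aug = A.Compose([A.HorizontalFlip(p=0.5)], seed=42)
serialized = A.to_dict(aug)
print(serialized["transform"]["seed"])  # 42: the base seed, not a worker-mixed value
```
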
@@ -1445,7 +1563,7 @@ def __init__(
                 msg = "You must set both first and second or set transforms argument."
                 raise ValueError(msg)
             transforms = [first, second]
-        super().__init__(transforms, p)
+        super().__init__(transforms=transforms, p=p)
         if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
             warnings.warn("Length of transforms is not equal to 2.", stacklevel=2)
@@ -1503,7 +1621,7 @@ def __init__(
         channels: Sequence[int] = (0, 1, 2),
         p: float = 1.0,
     ) -> None:
-        super().__init__(transforms, p)
+        super().__init__(transforms=transforms, p=p)
         self.channels = channels

     def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
@@ -1525,8 +1643,9 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
             sub_image = np.ascontiguousarray(selected_channels)

             for t in self.transforms:
-                sub_image = t(image=sub_image)["image"]
-                self._track_transform_params(t, sub_image)
+                sub_data = {"image": sub_image}
+                sub_image = t(**sub_data)["image"]
+                self._track_transform_params(t, sub_data)

             transformed_channels = cv2.split(sub_image)
             output_img = image.copy()
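
The fix above passes the same dict to the transform and to `_track_transform_params`, so parameter tracking sees the actual call payload rather than the bare array. Usage is unchanged; a short sketch follows. The hunk does not show the class name: `SelectiveChannelTransform` is an assumption based on the `channels` parameter and the `cv2.split` usage in albumentations' composition module.

```python
import numpy as np
import albumentations as A

image = np.random.randint(0, 256, (16, 16, 3), dtype=np.uint8)

# Class name assumed; the diff shows only its __init__ and __call__.
# The listed transforms apply only to channels 0 and 2; channel 1 passes through.
aug = A.SelectiveChannelTransform([A.Blur(p=1.0)], channels=(0, 2), p=1.0)
out = aug(image=image)["image"]
assert out.shape == image.shape
```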