1616""" 
1717
1818import  functools 
19- from  typing  import  List , Optional , Tuple , Union 
19+ import  os 
20+ from  typing  import  Callable , Dict , List , Optional , Tuple , Union 
2021
2122from  . import  logging 
2223from  .import_utils  import  is_torch_available , is_torch_npu_available , is_torch_version 
@@ -36,6 +37,116 @@ def maybe_allow_in_graph(cls):
         return cls


+# Behaviour flags
+BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+# Function definitions
+BACKEND_EMPTY_CACHE = {
+    "cuda": torch.cuda.empty_cache,
+    "xpu": torch.xpu.empty_cache,
+    "cpu": None,
+    "mps": torch.mps.empty_cache,
+    "default": None,
+}
+BACKEND_DEVICE_COUNT = {
+    "cuda": torch.cuda.device_count,
+    "xpu": torch.xpu.device_count,
+    "cpu": lambda: 0,
+    "mps": lambda: 0,
+    "default": 0,
+}
+BACKEND_MANUAL_SEED = {
+    "cuda": torch.cuda.manual_seed,
+    "xpu": torch.xpu.manual_seed,
+    "cpu": torch.manual_seed,
+    "mps": torch.mps.manual_seed,
+    "default": torch.manual_seed,
+}
+BACKEND_RESET_PEAK_MEMORY_STATS = {
+    "cuda": torch.cuda.reset_peak_memory_stats,
+    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+    "cuda": torch.cuda.reset_max_memory_allocated,
+    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+BACKEND_MAX_MEMORY_ALLOCATED = {
+    "cuda": torch.cuda.max_memory_allocated,
+    "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+    "cpu": 0,
+    "mps": 0,
+    "default": 0,
+}
+BACKEND_SYNCHRONIZE = {
+    "cuda": torch.cuda.synchronize,
+    "xpu": getattr(torch.xpu, "synchronize", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+
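Each table above maps a backend name to either a callable or a plain fallback value, so registering an additional accelerator is just a matter of adding keys. Inside this module it could look roughly like the sketch below; this is illustrative only and not part of the diff, and `torch.npu` is only available when the Ascend `torch_npu` extension is installed.

    # Hypothetical registration of an extra backend in the dispatch tables above.
    if is_torch_npu_available():
        BACKEND_SUPPORTS_TRAINING["npu"] = True
        BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
        BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
        BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
        BACKEND_SYNCHRONIZE["npu"] = torch.npu.synchronize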
+
+# Dispatch a call to the backend-specific function registered for `device` in the given table.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some table entries are plain values rather than callables (e.g. `None` or `0`);
+    # return them as-is so the caller can guard against `None` at the user level.
+    if not callable(fn):
+        return fn
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify a
+# device-agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
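Taken together, the helpers introduced above let test code seed, synchronize, and clear caches without branching on the device type. A minimal usage sketch follows; the import path is an assumption made for illustration, not something this diff establishes.

    import torch

    # Assumed import location for the helpers added in this diff.
    from diffusers.utils.torch_utils import (
        backend_empty_cache,
        backend_manual_seed,
        backend_supports_training,
        get_device,
    )

    device = get_device()  # e.g. "cuda", "xpu", "mps" or "cpu"

    backend_manual_seed(device, 0)  # seeds the generator of the matching accelerator
    x = torch.randn(2, 3, device=device)

    # Skip the backward pass on backends flagged as not supporting training (e.g. "mps").
    if backend_supports_training(device):
        x.requires_grad_(True)
        (x * 2.0).sum().backward()

    backend_empty_cache(device)  # returns None on backends without a cache to clear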
@@ -197,3 +308,30 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def enable_full_determinism():
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    """
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+    torch.use_deterministic_algorithms(True)
+
+    # Enable CUDNN deterministic mode
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
+
+
+torch_device = get_device()
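A rough sketch of how the new determinism toggles and the module-level `torch_device` might be exercised; the import path is again an assumption for illustration.

    import torch

    # Assumed import location; `torch_device` is the module-level value set above.
    from diffusers.utils.torch_utils import (
        disable_full_determinism,
        enable_full_determinism,
        torch_device,
    )

    enable_full_determinism()
    try:
        torch.manual_seed(0)
        a = torch.randn(4, 4, device=torch_device) @ torch.randn(4, 4, device=torch_device)
        torch.manual_seed(0)
        b = torch.randn(4, 4, device=torch_device) @ torch.randn(4, 4, device=torch_device)
        # With deterministic algorithms enabled and identical seeds, both runs should match.
        assert torch.equal(a, b)
    finally:
        disable_full_determinism()  # switch deterministic algorithms back off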