@@ -46,11 +46,14 @@ def get_backend_string(cls) -> str:
4646 return "CPU"
4747
4848 @classmethod
49- def change_backend_gempy (cls , engine_backend : AvailableBackends , use_gpu : bool = False , dtype : Optional [str ] = None ):
50- cls ._change_backend (engine_backend , use_pykeops = PYKEOPS , use_gpu = use_gpu , dtype = dtype )
49+ def change_backend_gempy (cls , engine_backend : AvailableBackends , use_gpu : bool = False ,
50+ dtype : Optional [str ] = None , grads :bool = False ):
51+ cls ._change_backend (engine_backend , use_pykeops = PYKEOPS , use_gpu = use_gpu , dtype = dtype ,
52+ grads = grads )
5153
5254 @classmethod
53- def _change_backend (cls , engine_backend : AvailableBackends , use_pykeops : bool = False , use_gpu : bool = True , dtype : Optional [str ] = None ):
55+ def _change_backend (cls , engine_backend : AvailableBackends , use_pykeops : bool = False ,
56+ use_gpu : bool = True , dtype : Optional [str ] = None , grads :bool = False ):
5457 cls .dtype = DEFAULT_TENSOR_DTYPE if dtype is None else dtype
5558 cls .dtype_obj = cls .dtype
5659 match engine_backend :
@@ -99,6 +102,21 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
99102 cls .dtype_obj = pytorch_copy .float32 if cls .dtype == "float32" else pytorch_copy .float64
100103 cls .tensor_types = pytorch_copy .Tensor
101104
105+ torch .set_num_threads (torch .get_num_threads ()) # Use all available threads
106+ cls .COMPUTE_GRADS = grads # Store the grads setting
107+ if grads is False :
108+ cls ._torch_no_grad_context = torch .no_grad ()
109+ cls ._torch_no_grad_context .__enter__ ()
110+ else :
111+ # If there was a previous context, exit it first
112+ if hasattr (cls , '_torch_no_grad_context' ) and cls ._torch_no_grad_context is not None :
113+ try :
114+ cls ._torch_no_grad_context .__exit__ (None , None , None )
115+ except :
116+ pass # Context might already be exited
117+ cls ._torch_no_grad_context = None
118+ torch .set_grad_enabled (True )
119+
102120 cls .use_pykeops = use_pykeops # TODO: Make this compatible with pykeops
103121 if (use_pykeops ):
104122 import pykeops
0 commit comments