@@ -277,7 +277,7 @@ def get_class_that_defined_method(f):
277277 return out
278278
279279 @classmethod
280- def func_name (cls , fn ):
280+ def get_func_name (cls , fn ):
281281 # produces a name like torchrl.module.Class.method or torchrl.module.function
282282 first = str (fn ).split ("." )[0 ][len ("<function " ) :]
283283 last = str (fn ).split ("." )[1 :]
@@ -300,10 +300,10 @@ def _get_cls(self, fn):
300300
301301 def module_set (self ):
302302 """Sets the function in its module, if it exists already."""
303- prev_setter = type (self )._implementations .get (self .func_name (self .fn ), None )
303+ prev_setter = type (self )._implementations .get (self .get_func_name (self .fn ), None )
304304 if prev_setter is not None :
305305 prev_setter .do_set = False
306- type(self )._implementations [self .func_name (self .fn )] = self
306+ type(self )._implementations [self .get_func_name (self .fn )] = self
307307 cls = self .get_class_that_defined_method (self .fn )
308308 if cls is not None :
309309 if cls .__class__ .__name__ == "function" :
@@ -329,11 +329,32 @@ def import_module(cls, module_name: Union[Callable, str]) -> str:
329329 module = module_name ()
330330 return module .__version__
331331
332+ _lazy_impl = collections .defaultdict (list )
333+
334+ def _delazify (self , func_name ):
335+ for local_call in implement_for ._lazy_impl [func_name ]:
336+ out = local_call ()
337+ return out
338+
332339 def __call__ (self , fn ):
340+ # function names are unique
341+ self .func_name = self .get_func_name (fn )
333342 self .fn = fn
343+ implement_for ._lazy_impl [self .func_name ].append (self ._call )
344+
345+ @wraps (fn )
346+ def _lazy_call_fn (* args , ** kwargs ):
347+ # first time we call the function, we also do the replacement.
348+ # This will cause the imports to occur only during the first call to fn
349+ return self ._delazify (self .func_name )(* args , ** kwargs )
350+
351+ return _lazy_call_fn
352+
353+ def _call (self ):
334354
335355 # If the module is missing replace the function with the mock.
336- func_name = self .func_name (self .fn )
356+ fn = self .fn
357+ func_name = self .func_name
337358 implementations = implement_for ._implementations
338359
339360 @wraps (fn )
0 commit comments