@@ -492,58 +492,54 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
492
492
"PYTORCH" : PYTORCH ,
493
493
}
494
494
495
- instantiated_default_mode = None
495
+ _CACHED_RUNTIME_MODES : dict [ str , Mode ] = {}
496
496
497
497
498
498
def get_mode (orig_string ):
499
499
if orig_string is None :
500
500
string = config .mode
501
501
else :
502
502
string = orig_string
503
+
503
504
if not isinstance (string , str ):
504
505
return string # it is hopefully already a mode...
505
506
506
- global instantiated_default_mode
507
- # The default mode is cached. However, config.mode can change
508
- # If instantiated_default_mode has the right class, use it.
509
- if orig_string is None and instantiated_default_mode :
510
- if string in predefined_modes :
511
- default_mode_class = predefined_modes [string ].__class__ .__name__
512
- else :
513
- default_mode_class = string
514
- if instantiated_default_mode .__class__ .__name__ == default_mode_class :
515
- return instantiated_default_mode
516
-
517
- if string in ("Mode" , "DebugMode" , "NanGuardMode" ):
518
- if string == "DebugMode" :
519
- # need to import later to break circular dependency.
520
- from .debugmode import DebugMode
521
-
522
- # DebugMode use its own linker.
523
- ret = DebugMode (optimizer = config .optimizer )
524
- elif string == "NanGuardMode" :
525
- # need to import later to break circular dependency.
526
- from .nanguardmode import NanGuardMode
527
-
528
- # NanGuardMode use its own linker.
529
- ret = NanGuardMode (True , True , True , optimizer = config .optimizer )
530
- else :
531
- # TODO: Can't we look up the name and invoke it rather than using eval here?
532
- ret = eval (string + "(linker=config.linker, optimizer=config.optimizer)" )
533
- elif string in predefined_modes :
534
- ret = predefined_modes [string ]
535
- else :
536
- raise Exception (f"No predefined mode exist for string: { string } " )
507
+ # Keep the original string for error messages
508
+ upper_string = string .upper ()
537
509
538
- if orig_string is None :
539
- # Build and cache the default mode
540
- if config .optimizer_excluding :
541
- ret = ret .excluding (* config .optimizer_excluding .split (":" ))
542
- if config .optimizer_including :
543
- ret = ret .including (* config .optimizer_including .split (":" ))
544
- if config .optimizer_requiring :
545
- ret = ret .requiring (* config .optimizer_requiring .split (":" ))
546
- instantiated_default_mode = ret
510
+ if upper_string in predefined_modes :
511
+ return predefined_modes [upper_string ]
512
+
513
+ global _CACHED_RUNTIME_MODES
514
+
515
+ if upper_string in _CACHED_RUNTIME_MODES :
516
+ return _CACHED_RUNTIME_MODES [upper_string ]
517
+
518
+ # Need to define the mode for the first time
519
+ if upper_string == "MODE" :
520
+ ret = Mode (linker = config .linker , optimizer = config .optimizer )
521
+ elif upper_string in ("DEBUGMODE" , "DEBUG_MODE" ):
522
+ from pytensor .compile .debugmode import DebugMode
523
+
524
+ # DebugMode use its own linker.
525
+ ret = DebugMode (optimizer = config .optimizer )
526
+ elif upper_string == "NANGUARDMODE" :
527
+ from pytensor .compile .nanguardmode import NanGuardMode
528
+
529
+ # NanGuardMode use its own linker.
530
+ ret = NanGuardMode (True , True , True , optimizer = config .optimizer )
531
+
532
+ else :
533
+ raise ValueError (f"No predefined mode exist for string: { string } " )
534
+
535
+ if config .optimizer_excluding :
536
+ ret = ret .excluding (* config .optimizer_excluding .split (":" ))
537
+ if config .optimizer_including :
538
+ ret = ret .including (* config .optimizer_including .split (":" ))
539
+ if config .optimizer_requiring :
540
+ ret = ret .requiring (* config .optimizer_requiring .split (":" ))
541
+ # Cache the mode for next time
542
+ _CACHED_RUNTIME_MODES [upper_string ] = ret
547
543
548
544
return ret
549
545
0 commit comments