     _set_module_training_mode,
 )
 from torchtnt.framework.callback import Callback
+from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
 from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.framework.unit import TPredictData, TPredictUnit
 from torchtnt.framework.utils import get_timing_context
@@ -80,7 +81,10 @@ def predict(
         call on_predict_end on unit first and then callbacks
     """
     _log_api_usage("predict")
-    callback_handler = CallbackHandler(callbacks or [])
+    callbacks = callbacks or []
+    callback_handler = CallbackHandler(callbacks)
+    checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks)
+
     state = State(
         entry_point=EntryPoint.PREDICT,
         predict_state=PhaseState(
@@ -90,7 +94,13 @@ def predict(
         timer=timer,
     )
     try:
-        _predict_impl(state, predict_unit, callback_handler)
+        # all_gather using inference_mode with gloo backend is not supported. Since this collective
+        # is necessary for checkpointing, we need to use torch.no_grad instead.
+        # TODO: remove this once all_gather is supported in inference_mode.
+        inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode
+        with inference_ctx():
+            _predict_impl(state, predict_unit, callback_handler)
+
         logger.info("Finished predict")
         if state.timer:
             logger.info(get_timer_summary(state.timer))
@@ -104,7 +114,6 @@ def predict(
         raise e


-@torch.inference_mode()
 def _predict_impl(
     state: State,
     predict_unit: TPredictUnit,
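
A minimal standalone sketch of the context-selection pattern this change introduces, using simplified `Callback`/`BaseCheckpointer` stand-ins rather than the real torchtnt classes: when a checkpointing callback is present, `torch.no_grad` is selected so that `all_gather` over the gloo backend keeps working; otherwise `torch.inference_mode` is used.

```python
import torch


# Simplified stand-ins for torchtnt's Callback / BaseCheckpointer classes,
# used here only so the example is self-contained.
class Callback:
    pass


class BaseCheckpointer(Callback):
    pass


def run_predict(callbacks=None):
    callbacks = callbacks or []
    # all_gather under inference_mode is unsupported with the gloo backend,
    # so fall back to torch.no_grad whenever a checkpointing callback exists.
    checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks)
    inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode

    with inference_ctx():
        x = torch.randn(4, 4)
        y = x @ x
        # inference_mode produces "inference tensors"; no_grad does not.
        print(inference_ctx.__name__, "is_inference:", y.is_inference())


run_predict([BaseCheckpointer()])  # -> no_grad is_inference: False
run_predict()                      # -> inference_mode is_inference: True
```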