diff --git a/torchx/util/session.py b/torchx/util/session.py index f03b2b10b..6068169fb 100644 --- a/torchx/util/session.py +++ b/torchx/util/session.py @@ -7,9 +7,12 @@ # pyre-strict +import os import uuid from typing import Optional +TORCHX_INTERNAL_SESSION_ID = "TORCHX_INTERNAL_SESSION_ID" + CURRENT_SESSION_ID: Optional[str] = None @@ -22,6 +25,10 @@ def get_session_id_or_create_new() -> str: global CURRENT_SESSION_ID if CURRENT_SESSION_ID: return CURRENT_SESSION_ID + env_session_id = os.getenv(TORCHX_INTERNAL_SESSION_ID) + if env_session_id: + CURRENT_SESSION_ID = env_session_id + return CURRENT_SESSION_ID session_id = str(uuid.uuid4()) CURRENT_SESSION_ID = session_id return session_id