diff --git a/torchx/components/utils.py b/torchx/components/utils.py index 62bf33664..e64069a14 100644 --- a/torchx/components/utils.py +++ b/torchx/components/utils.py @@ -83,6 +83,7 @@ def sh( env: Optional[Dict[str, str]] = None, max_retries: int = 0, mounts: Optional[List[str]] = None, + entrypoint: Optional[str] = None, ) -> specs.AppDef: """ Runs the provided command via sh. Currently sh does not support @@ -100,6 +101,7 @@ def sh( max_retries: the number of scheduler retries allowed mounts: mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). See scheduler documentation for more info. + entrypoint: the entrypoint to use for the command (defaults to sh) """ escaped_args = " ".join(shlex.quote(arg) for arg in args) @@ -113,7 +115,7 @@ def sh( specs.Role( name="sh", image=image, - entrypoint="sh", + entrypoint=entrypoint or "sh", args=["-c", escaped_args], num_replicas=num_replicas, resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),