diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index d79fa785f6..1524d6a8ab 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -65,12 +65,16 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ _PYTORCH_TASK_TYPE = "pytorch" + _PYTORCH_TASK_TYPE_STANDALONE = "python-task" def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): + + task_type = self._PYTORCH_TASK_TYPE_STANDALONEE if task_config.num_workers == 0 else self._PYTORCH_TASK_TYPE + super().__init__( task_config, task_function, - task_type=self._PYTORCH_TASK_TYPE, + task_type=task_type, **kwargs, ) diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 00eb6c0953..929811cd53 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -31,3 +31,14 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") assert my_pytorch_task.task_type == "pytorch" + +def test_zero_worker(): + @task( + task_config=PyTorch(num_workers=0), + cache=True, + cache_version="1", + requests=Resources(cpu="1"), + ) + def my_pytorch_task(x: int, y: str) -> int: + return x + assert my_pytorch_task.task_type == "python-task"