diff --git a/flytekit/extras/accelerators.py b/flytekit/extras/accelerators.py index 466714b661..8a9a339584 100644 --- a/flytekit/extras/accelerators.py +++ b/flytekit/extras/accelerators.py @@ -142,6 +142,10 @@ def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator: #: `NVIDIA H200 GPU https://www.nvidia.com/en-us/data-center/h200 H200 = GPUAccelerator("nvidia-h200") +#: use this constant to specify that the task should run on an +#: `NVIDIA RTX-PRO-6000 GPU https://www.nvidia.com/en-us/products/workstations/professional-desktop-gpus/rtx-pro-6000/ +RTX_PRO_6000 = GPUAccelerator("nvidia-rtx-pro-6000") + class MultiInstanceGPUAccelerator(BaseAccelerator): """