38 changes: 22 additions & 16 deletions skyrl-train/pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
"tensordict",
"jaxtyping",
"skyrl-gym",
"flash-attn",
"flash-attn; sys_platform == 'linux'", # CUDA-only, skip on macOS
"polars",
"s3fs",
"fastapi",
@@ -83,14 +83,20 @@ flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"}

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym" , editable = true }
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
# Use CUDA PyTorch index on Linux, PyPI (CPU) on macOS
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
torchvision = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
# We use `flashinfer-jit-cache` to avoid slow JIT compilation on first run.
# Different inference engines may pin different compatible flashinfer versions, so we allow pinning separate versions for vllm and sglang.
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" }
# CUDA-only, skip on macOS
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm' and sys_platform == 'linux'" }
flashinfer-python = [
{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" },
{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" }
{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm' and sys_platform == 'linux'" },
{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm' and sys_platform == 'linux'" }
Comment on lines +98 to +99
Contributor

medium
For better readability and maintainability, you can combine these two flashinfer-python source definitions into a single entry. Since they point to the same URL and the mcore and sglang extras are mutually exclusive according to your [tool.uv].conflicts configuration, their markers can be simplified and merged.

    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "(extra == 'mcore' or extra == 'sglang') and sys_platform == 'linux'" }

]

[project.optional-dependencies]
@@ -119,35 +125,35 @@ sandboxes = [
]
vllm = [
"vllm==0.11.0",
Contributor

critical
The vllm package is only available for Linux. To achieve the goal of this PR and allow installation on other platforms like macOS, you should add a platform marker to this dependency as well.

    "vllm==0.11.0; sys_platform == 'linux'",

"flash-attn==2.8.3",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.8.0",
"flashinfer-python",
"flashinfer-jit-cache",
"flashinfer-python; sys_platform == 'linux'",
"flashinfer-jit-cache; sys_platform == 'linux'",
"torchvision"
]
sglang = [
"sglang[srt,openai,torch_memory_saver]==0.4.8.post1", # 0.4.9.post1 causes non-colocate weight broadcast to hang
Contributor

critical
The sglang package is only available for Linux. To allow installation on other platforms, you should add a platform marker to this dependency.

    "sglang[srt,openai,torch_memory_saver]==0.4.8.post1; sys_platform == 'linux'",

"flashinfer-python",
"flash-attn==2.8.3",
"flashinfer-python; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.7.1",
"torchvision",
]
mcore = [
"transformer-engine[pytorch]==2.7.0",
"flash-attn==2.7.4.post1",
"transformer-engine[pytorch]==2.7.0; sys_platform == 'linux'",
"flash-attn==2.7.4.post1; sys_platform == 'linux'",
"vllm==0.10.1.1",
"torch==2.7.1",
"flashinfer-python",
"flashinfer-python; sys_platform == 'linux'",
"torchvision",
"megatron-bridge==0.1.0rc4",
"megatron-core==0.14.0",
]
flashrl = [
# NOTE: Custom vLLM wheel must be installed separately.
# See examples/flash_rl/README.md for installation instructions.
"flash-attn==2.8.3",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.7.0",
"flashinfer-python",
"flashinfer-python; sys_platform == 'linux'",
"torchvision",
]
miniswe = [