3 changes: 2 additions & 1 deletion flashinfer/compilation_context.py
@@ -42,7 +42,8 @@ def __init__(self):
self.TARGET_CUDA_ARCHS.add((int(major), minor))
else:
try:
for device in range(torch.cuda.device_count()):
# for device in range(torch.cuda.device_count()):
for device in range(1):
major, minor = torch.cuda.get_device_capability(device)
if major >= 9:
minor = str(minor) + "a"
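Note on the hunk above: arch detection now probes only device 0 instead of every visible GPU. A minimal standalone sketch of the resulting logic (names from the file; assumes all visible GPUs share one architecture):

```python
import torch

TARGET_CUDA_ARCHS = set()
if torch.cuda.is_available():
    # Probe only device 0; assumes a homogeneous multi-GPU machine.
    major, minor = torch.cuda.get_device_capability(0)
    if major >= 9:
        # sm90+ targets use the arch-specific "a" suffix (e.g. 9.0a).
        minor = str(minor) + "a"
    TARGET_CUDA_ARCHS.add((major, minor))
```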
9 changes: 6 additions & 3 deletions flashinfer/fp4_quantization.py
@@ -174,7 +174,8 @@ def fp4_quantize_sm100(
- Scale factors tensor with shape determined by layout and sf_vec_size
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
return module.fp4_quantize(
input,
global_scale,
@@ -355,9 +356,11 @@ def fp4_quantize(

assert input.shape[-1] % sf_vec_size == 0
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
# get input device sm version
major, minor = get_compute_capability(input.device)
# major, minor = get_compute_capability(input.device)
major, minor = get_compute_capability(input.place)
x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
input,
global_scale,
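The recurring substitution in this file and the ones below is torch's `Tensor.device` → `Tensor.place`, the attribute Paddle tensors expose instead. A hypothetical shim that tolerates both, shown for illustration only (the diff rewrites call sites directly rather than adding such a helper):

```python
def tensor_device(t):
    """Return the device of a tensor under either framework.

    torch tensors carry `.device`; paddle tensors carry `.place`.
    Purely illustrative -- not part of this diff.
    """
    place = getattr(t, "place", None)
    return place if place is not None else t.device
```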
16 changes: 12 additions & 4 deletions flashinfer/fused_moe/core.py
@@ -568,7 +568,8 @@ def cutlass_fused_moe(
enable_pdl: Optional[bool] = None,
) -> List[torch.Tensor]:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
tuner = AutoTuner.get()
MoERunner.refine_tuning_config(tune_max_num_tokens)

@@ -868,7 +869,8 @@ def cutlass_fused_moe(
raise NotImplementedError("min latency mode not yet implemented for Blackwell.")

if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)

num_rows = input.shape[0]
if min_latency_mode:
@@ -877,10 +879,16 @@
output_shape = (num_rows, hidden_size)

if output is None:
output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
# output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
else:
check_shape_dtype_device(
output, output_shape, output_dtype, input.device, "output"
# output, output_shape, output_dtype, input.device, "output"
output,
output_shape,
output_dtype,
input.place,
"output",
)

major, minor = torch.cuda.get_device_capability()
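Both `cutlass_fused_moe` paths default `enable_pdl` through `device_support_pdl`, which (per the `flashinfer/utils.py` hunk below) reduces to a compute-capability check: PDL, CUDA's programmatic dependent launch, needs SM 9.0 or newer. A self-contained sketch of the defaulting pattern (`run_kernel` is an illustrative stand-in):

```python
from typing import Optional

import torch

def device_support_pdl(device: torch.device) -> bool:
    # PDL (programmatic dependent launch) needs Hopper-class GPUs,
    # i.e. compute capability 9.x or newer.
    major, _ = torch.cuda.get_device_capability(device.index)
    return major >= 9

def run_kernel(x: torch.Tensor, enable_pdl: Optional[bool] = None) -> None:
    if enable_pdl is None:  # defaulting pattern used throughout this diff
        enable_pdl = device_support_pdl(x.device)
    ...
```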
60 changes: 50 additions & 10 deletions flashinfer/jit/cpp_ext.py
@@ -12,15 +12,20 @@

import torch
from torch.utils.cpp_extension import (
_TORCH_PATH,
CUDA_HOME,
_get_num_workers,
_get_pybind11_abi_build_flags,
)

from . import env as jit_env
from ..utils import use_paddle_compatible_api
from ..compilation_context import CompilationContext

if use_paddle_compatible_api():
_TORCH_PATH = torch.__path__[0]
else:
from torch.utils.cpp_extension import _TORCH_PATH # type: ignore[no-redef]


@functools.cache
def get_cuda_path() -> str:
@@ -75,13 +80,26 @@ def generate_ninja_build_for_op(
) -> str:
system_includes = [
sysconfig.get_path("include"),
"$torch_home/include",
"$torch_home/include/torch/csrc/api/include",
"$cuda_home/include",
"$cuda_home/include/cccl",
jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
jit_env.FLASHINFER_CSRC_DIR.resolve(),
]
if use_paddle_compatible_api():
system_includes.extend(
[
"$torch_home/include",
"$torch_home/include/paddle/phi/api/include/compat",
"$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include",
]
)
else:
system_includes.extend(
[
"$torch_home/include",
"$torch_home/include/torch/csrc/api/include",
]
)
system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())

@@ -90,6 +108,8 @@ def generate_ninja_build_for_op(
"-DTORCH_API_INCLUDE_EXTENSION_H",
"-DPy_LIMITED_API=0x03090000",
]
if use_paddle_compatible_api():
common_cflags.append("-DPADDLE_WITH_CUDA")
common_cflags += _get_pybind11_abi_build_flags()
common_cflags += _get_glibcxx_abi_build_flags()
if extra_include_dirs is not None:
@@ -144,15 +164,35 @@ def generate_ninja_build_for_op(

ldflags = [
"-shared",
"-L$torch_home/lib",
"-L$cuda_home/lib64",
"-lc10",
"-lc10_cuda",
"-ltorch_cpu",
"-ltorch_cuda",
"-ltorch",
"-lcudart",
]
if use_paddle_compatible_api():
ldflags.extend(
[
"-shared",
"-L$torch_home/libs",
"-L$torch_home/base",
"-L$cuda_home/lib64",
"-lpaddle",
"-lphi",
"-lphi_core",
"-lphi_gpu",
"-lcommon",
"-lcudart",
]
)
else:
ldflags.extend(
[
"-L$torch_home/lib",
"-L$cuda_home/lib64",
"-lc10",
"-lc10_cuda",
"-ltorch_cpu",
"-ltorch_cuda",
"-ltorch",
]
)

env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
if env_extra_ldflags:
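Every branch added in this file keys off `use_paddle_compatible_api()`, imported from `flashinfer/utils.py` (see its hunk below). The gate is a plain environment-variable check; sketched here with illustrative usage:

```python
import os

def use_paddle_compatible_api() -> bool:
    # Opt-in switch for the Paddle code paths: unset/"0" keeps the torch paths.
    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]

# Illustrative usage: opt in before importing flashinfer.
os.environ["PADDLE_COMPATIBLE_API"] = "1"
assert use_paddle_compatible_api()
```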
12 changes: 6 additions & 6 deletions flashinfer/jit/utils.py
@@ -38,9 +38,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
torch.int8: "int8_t",
torch.uint8: "uint8_t",
torch.int32: "int32_t",
torch.uint32: "uint32_t",
# torch.uint32: "uint32_t",
torch.int64: "int64_t",
torch.uint64: "uint64_t",
# torch.uint64: "uint64_t",
}

dtype_cutlass_map = {
@@ -51,9 +51,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
torch.int8: "cutlass::int8_t",
torch.uint8: "cutlass::uint8_t",
torch.int32: "cutlass::int32_t",
torch.uint32: "cutlass::uint32_t",
# torch.uint32: "cutlass::uint32_t",
torch.int64: "cutlass::int64_t",
torch.uint64: "cutlass::uint64_t",
# torch.uint64: "cutlass::uint64_t",
}

filename_safe_dtype_map = {
@@ -64,9 +64,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
torch.int8: "i8",
torch.uint8: "u8",
torch.int32: "i32",
torch.uint32: "u32",
# torch.uint32: "u32",
torch.int64: "i64",
torch.uint64: "u64",
# torch.uint64: "u64",
}

pos_encoding_mode_literal = {
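The entries commented out above all involve `torch.uint32`/`torch.uint64`, presumably because the Paddle proxy does not expose those dtypes. Under that assumption, a defensive alternative is to probe for each attribute instead of hard-coding it (illustrative sketch, not what the diff does):

```python
import torch

cpp_names = {
    "int8": "int8_t", "uint8": "uint8_t",
    "int32": "int32_t", "uint32": "uint32_t",
    "int64": "int64_t", "uint64": "uint64_t",
}

# Keep only the dtypes the active torch (real or Paddle-proxied) exposes.
dtype_map = {
    getattr(torch, name): cpp
    for name, cpp in cpp_names.items()
    if getattr(torch, name, None) is not None
}
```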
2 changes: 2 additions & 0 deletions flashinfer/sampling.py
@@ -14,6 +14,8 @@
limitations under the License.
"""

from __future__ import annotations # for torch.Generator

import functools
from types import SimpleNamespace
from typing import Optional, Union
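Why the future import helps: with `from __future__ import annotations`, signature annotations such as `torch.Generator` are stored as strings and never evaluated at import time, so the module loads cleanly even if the proxied `torch` lacks the annotated name (the apparent motivation, per the inline comment). A minimal illustration (`sample` is a hypothetical function, not from the file):

```python
from __future__ import annotations  # every annotation becomes a lazy string

import torch

def sample(logits: torch.Tensor, generator: torch.Generator | None = None):
    # Without the future import, `torch.Generator | None` is evaluated while
    # the module loads; with it, the name is only resolved by tools that
    # introspect __annotations__.
    ...
```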
33 changes: 21 additions & 12 deletions flashinfer/utils.py
@@ -22,10 +22,8 @@

import torch
import torch.version
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

from .jit import gen_jit_spec, env as jit_env
import flashinfer

IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"

@@ -222,6 +220,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:

@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
return torch.device.cuda.get_device_capability(device.gpu_device_id())
if device.type != "cuda":
raise ValueError("device must be a cuda device")
return torch.cuda.get_device_capability(device.index)
@@ -240,7 +239,16 @@ def _check_cached_qkv_data_type(
)


if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
def use_paddle_compatible_api() -> bool:
return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]


if (
use_paddle_compatible_api()
or IS_BUILDING_DOCS
or torch.torch_version.TorchVersion(torch.torch_version.__version__)
< torch.torch_version.TorchVersion("2.4")
):

def register_custom_op(
name: str,
@@ -477,29 +485,30 @@ def check_shape_dtype_device(
expected_device: Optional[torch.device],
name: str,
) -> None:
if expected_shape and x.shape != torch.Size(expected_shape):
if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
raise ValueError(
f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
)
if expected_dtype and x.dtype != expected_dtype:
raise ValueError(
f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
)
if expected_device and x.device != expected_device:
# if expected_device and x.device != expected_device:
if expected_device and x.place != expected_device:
raise ValueError(
f"Invalid device of {name}: expected {expected_device}, got {x.device}"
)


def gen_logging_module():
return gen_jit_spec(
return flashinfer.jit.gen_jit_spec(
"logging",
[
jit_env.FLASHINFER_CSRC_DIR / "logging.cc",
flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc",
],
extra_include_paths=[
jit_env.SPDLOG_INCLUDE_DIR,
jit_env.FLASHINFER_INCLUDE_DIR,
flashinfer.jit.env.SPDLOG_INCLUDE_DIR,
flashinfer.jit.env.FLASHINFER_INCLUDE_DIR,
],
)

Expand Down Expand Up @@ -533,8 +542,8 @@ def set_log_level(lvl_str: str) -> None:


def device_support_pdl(device: torch.device) -> bool:
if device.type != "cuda":
return False
# if device.type != "cuda":
# return False
major, _ = get_compute_capability(device)
return major >= 9

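On `tuple(x.shape)` in `check_shape_dtype_device`: `torch.Size` is a tuple subclass, while Paddle tensors report `shape` as a plain list, and a list never compares equal to a tuple. Normalizing through `tuple()` makes the check framework-agnostic:

```python
import torch

expected = torch.Size([2, 3])
torch_shape = torch.Size([2, 3])  # what a torch tensor's .shape is
paddle_like_shape = [2, 3]        # paddle tensors report shape as a list

assert torch_shape == expected
assert paddle_like_shape != expected         # a list never equals a tuple
assert tuple(paddle_like_shape) == expected  # normalized comparison works
```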
13 changes: 12 additions & 1 deletion setup.py
@@ -29,6 +29,10 @@
enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())


def use_paddle_compatible_api() -> bool:
return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]


def write_if_different(path: Path, content: str) -> None:
if path.exists() and path.read_text() == content:
return
@@ -55,7 +59,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
cmdclass: Mapping[str, type[setuptools.Command]] = {}
install_requires = [
"numpy",
"torch",
"ninja",
"requests",
"pynvml",
@@ -66,9 +69,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"packaging>=24.2",
"nvidia-cudnn-frontend>=1.13.0",
]
if not use_paddle_compatible_api():
install_requires.append("torch")

generate_build_meta({})

if enable_aot:
if use_paddle_compatible_api():
import paddle

paddle.compat.enable_torch_proxy()

import torch
import torch.utils.cpp_extension as torch_cpp_ext
from packaging.version import Version
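The AOT path enables Paddle's torch proxy before importing `torch`, so the import is served by the compatibility layer. A sketch of the sequencing, assuming a Paddle build that provides `paddle.compat.enable_torch_proxy` (as this diff requires):

```python
import paddle

# Enable the proxy before the first `import torch` so the compatibility
# layer can take over the name (assumed ordering requirement).
paddle.compat.enable_torch_proxy()

import torch  # now served by Paddle's compatibility layer

x = torch.ones((2, 2))
print(type(x))  # expected to be a Paddle tensor type under the proxy
```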