diff --git a/flashinfer/compilation_context.py b/flashinfer/compilation_context.py
index 5d24643f55..5e7078c5f6 100644
--- a/flashinfer/compilation_context.py
+++ b/flashinfer/compilation_context.py
@@ -42,7 +42,8 @@ def __init__(self):
                 self.TARGET_CUDA_ARCHS.add((int(major), minor))
         else:
             try:
-                for device in range(torch.cuda.device_count()):
+                # for device in range(torch.cuda.device_count()):
+                for device in range(1):
                     major, minor = torch.cuda.get_device_capability(device)
                     if major >= 9:
                         minor = str(minor) + "a"
diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py
index fc6bd96610..2ec4ae4be0 100644
--- a/flashinfer/fp4_quantization.py
+++ b/flashinfer/fp4_quantization.py
@@ -174,7 +174,8 @@ def fp4_quantize_sm100(
         - Scale factors tensor with shape determined by layout and sf_vec_size
     """
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     return module.fp4_quantize(
         input,
         global_scale,
@@ -355,9 +356,11 @@ def fp4_quantize(
     assert input.shape[-1] % sf_vec_size == 0
 
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     # get input device sm version
-    major, minor = get_compute_capability(input.device)
+    # major, minor = get_compute_capability(input.device)
+    major, minor = get_compute_capability(input.place)
     x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
         input,
         global_scale,
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index ad6169c515..85dff5723d 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -568,7 +568,8 @@ def cutlass_fused_moe(
     enable_pdl: Optional[bool] = None,
 ) -> List[torch.Tensor]:
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
 
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
@@ -868,7 +869,8 @@ def cutlass_fused_moe(
         raise NotImplementedError("min latency mode not yet implemented for Blackwell.")
 
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     num_rows = input.shape[0]
     if min_latency_mode:
@@ -877,10 +879,16 @@ def cutlass_fused_moe(
         output_shape = (num_rows, hidden_size)
 
     if output is None:
-        output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
     else:
         check_shape_dtype_device(
-            output, output_shape, output_dtype, input.device, "output"
+            # output, output_shape, output_dtype, input.device, "output"
+            output,
+            output_shape,
+            output_dtype,
+            input.place,
+            "output",
         )
 
     major, minor = torch.cuda.get_device_capability()
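Note: the `.device` to `.place` swaps above are unconditional, so these call sites assume Paddle's torch proxy is always active. A minimal backend-neutral sketch (the `_tensor_device` helper is hypothetical, not part of this patch) that would keep stock torch callers working:

    def _tensor_device(t):
        # Paddle tensors expose .place; torch tensors expose .device.
        return t.place if hasattr(t, "place") else t.device

    # e.g. in cutlass_fused_moe:
    enable_pdl = device_support_pdl(_tensor_device(input))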
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index a0777b9e37..73277a76f4 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -12,15 +12,20 @@
 
 import torch
 from torch.utils.cpp_extension import (
-    _TORCH_PATH,
     CUDA_HOME,
     _get_num_workers,
     _get_pybind11_abi_build_flags,
 )
 
 from . import env as jit_env
+from ..utils import use_paddle_compatible_api
 from ..compilation_context import CompilationContext
 
+if use_paddle_compatible_api():
+    _TORCH_PATH = torch.__path__[0]
+else:
+    from torch.utils.cpp_extension import _TORCH_PATH  # type: ignore[no-redef]
+
 
 @functools.cache
 def get_cuda_path() -> str:
@@ -75,13 +80,26 @@ def generate_ninja_build_for_op(
 ) -> str:
     system_includes = [
         sysconfig.get_path("include"),
-        "$torch_home/include",
-        "$torch_home/include/torch/csrc/api/include",
         "$cuda_home/include",
         "$cuda_home/include/cccl",
         jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
         jit_env.FLASHINFER_CSRC_DIR.resolve(),
     ]
+    if use_paddle_compatible_api():
+        system_includes.extend(
+            [
+                "$torch_home/include",
+                "$torch_home/include/paddle/phi/api/include/compat",
+                "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include",
+            ]
+        )
+    else:
+        system_includes.extend(
+            [
+                "$torch_home/include",
+                "$torch_home/include/torch/csrc/api/include",
+            ]
+        )
     system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
     system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())
 
@@ -90,6 +108,8 @@ def generate_ninja_build_for_op(
         "-DTORCH_API_INCLUDE_EXTENSION_H",
         "-DPy_LIMITED_API=0x03090000",
     ]
+    if use_paddle_compatible_api():
+        common_cflags.append("-DPADDLE_WITH_CUDA")
     common_cflags += _get_pybind11_abi_build_flags()
     common_cflags += _get_glibcxx_abi_build_flags()
     if extra_include_dirs is not None:
@@ -144,15 +164,35 @@ def generate_ninja_build_for_op(
     ldflags = [
         "-shared",
-        "-L$torch_home/lib",
-        "-L$cuda_home/lib64",
-        "-lc10",
-        "-lc10_cuda",
-        "-ltorch_cpu",
-        "-ltorch_cuda",
-        "-ltorch",
         "-lcudart",
     ]
+    if use_paddle_compatible_api():
+        ldflags.extend(
+            [
+                "-shared",
+                "-L$torch_home/libs",
+                "-L$torch_home/base",
+                "-L$cuda_home/lib64",
+                "-lpaddle",
+                "-lphi",
+                "-lphi_core",
+                "-lphi_gpu",
+                "-lcommon",
+                "-lcudart",
+            ]
+        )
+    else:
+        ldflags.extend(
+            [
+                "-L$torch_home/lib",
+                "-L$cuda_home/lib64",
+                "-lc10",
+                "-lc10_cuda",
+                "-ltorch_cpu",
+                "-ltorch_cuda",
+                "-ltorch",
+            ]
+        )
 
     env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
     if env_extra_ldflags:
diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py
index 4e19212e14..5f621f83e6 100644
--- a/flashinfer/jit/utils.py
+++ b/flashinfer/jit/utils.py
@@ -38,9 +38,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
     torch.int8: "int8_t",
     torch.uint8: "uint8_t",
     torch.int32: "int32_t",
-    torch.uint32: "uint32_t",
+    # torch.uint32: "uint32_t",
     torch.int64: "int64_t",
-    torch.uint64: "uint64_t",
+    # torch.uint64: "uint64_t",
 }
 
 dtype_cutlass_map = {
@@ -51,9 +51,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
     torch.int8: "cutlass::int8_t",
     torch.uint8: "cutlass::uint8_t",
     torch.int32: "cutlass::int32_t",
-    torch.uint32: "cutlass::uint32_t",
+    # torch.uint32: "cutlass::uint32_t",
     torch.int64: "cutlass::int64_t",
-    torch.uint64: "cutlass::uint64_t",
+    # torch.uint64: "cutlass::uint64_t",
 }
 
 filename_safe_dtype_map = {
@@ -64,9 +64,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
     torch.int8: "i8",
     torch.uint8: "u8",
     torch.int32: "i32",
-    torch.uint32: "u32",
+    # torch.uint32: "u32",
     torch.int64: "i64",
-    torch.uint64: "u64",
+    # torch.uint64: "u64",
 }
 
 pos_encoding_mode_literal = {
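Note: `torch.uint32` and `torch.uint64` are commented out of the three dtype maps above because Paddle's compat layer does not define them. A sketch of an alternative guard (illustrative only; `dtype_map` stands in for the real maps) that keeps the entries on stock torch and skips them under the proxy:

    import torch

    dtype_map = {torch.int32: "int32_t", torch.int64: "int64_t"}
    # Register the unsigned wide types only when the active framework provides them.
    for _attr, _cpp in (("uint32", "uint32_t"), ("uint64", "uint64_t")):
        _dtype = getattr(torch, _attr, None)
        if _dtype is not None:
            dtype_map[_dtype] = _cpp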
""" +from __future__ import annotations # for torch.Generator + import functools from types import SimpleNamespace from typing import Optional, Union diff --git a/flashinfer/utils.py b/flashinfer/utils.py index ab1a1fa71e..3f2ee016ba 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -22,10 +22,8 @@ import torch import torch.version -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version -from .jit import gen_jit_spec, env as jit_env +import flashinfer IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1" @@ -222,6 +220,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @functools.cache def get_compute_capability(device: torch.device) -> Tuple[int, int]: + return torch.device.cuda.get_device_capability(device.gpu_device_id()) if device.type != "cuda": raise ValueError("device must be a cuda device") return torch.cuda.get_device_capability(device.index) @@ -240,7 +239,16 @@ def _check_cached_qkv_data_type( ) -if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"): +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + + +if ( + use_paddle_compatible_api() + or IS_BUILDING_DOCS + or torch.torch_version.TorchVersion(torch.torch_version.__version__) + < torch.torch_version.TorchVersion("2.4") +): def register_custom_op( name: str, @@ -477,7 +485,7 @@ def check_shape_dtype_device( expected_device: Optional[torch.device], name: str, ) -> None: - if expected_shape and x.shape != torch.Size(expected_shape): + if expected_shape and tuple(x.shape) != torch.Size(expected_shape): raise ValueError( f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}" ) @@ -485,21 +493,22 @@ def check_shape_dtype_device( raise ValueError( f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}" ) - if expected_device and x.device != expected_device: + # if expected_device and x.device != expected_device: + if expected_device and x.place != expected_device: raise ValueError( f"Invalid device of {name}: expected {expected_device}, got {x.device}" ) def gen_logging_module(): - return gen_jit_spec( + return flashinfer.jit.gen_jit_spec( "logging", [ - jit_env.FLASHINFER_CSRC_DIR / "logging.cc", + flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc", ], extra_include_paths=[ - jit_env.SPDLOG_INCLUDE_DIR, - jit_env.FLASHINFER_INCLUDE_DIR, + flashinfer.jit.env.SPDLOG_INCLUDE_DIR, + flashinfer.jit.env.FLASHINFER_INCLUDE_DIR, ], ) @@ -533,8 +542,8 @@ def set_log_level(lvl_str: str) -> None: def device_support_pdl(device: torch.device) -> bool: - if device.type != "cuda": - return False + # if device.type != "cuda": + # return False major, _ = get_compute_capability(device) return major >= 9 diff --git a/setup.py b/setup.py index 3949ea8ce0..1f865ca155 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,10 @@ enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir()) +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + + def write_if_different(path: Path, content: str) -> None: if path.exists() and path.read_text() == content: return @@ -55,7 +59,6 @@ def generate_build_meta(aot_build_meta: dict) -> None: cmdclass: Mapping[str, type[setuptools.Command]] = {} install_requires = [ "numpy", - "torch", "ninja", "requests", "pynvml", @@ -66,9 +69,17 @@ def generate_build_meta(aot_build_meta: dict) -> None: 
"packaging>=24.2", "nvidia-cudnn-frontend>=1.13.0", ] +if not use_paddle_compatible_api(): + install_requires.append("torch") + generate_build_meta({}) if enable_aot: + if use_paddle_compatible_api(): + import paddle + + paddle.compat.enable_torch_proxy() + import torch import torch.utils.cpp_extension as torch_cpp_ext from packaging.version import Version