diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py
index 64da794662..d963cea703 100644
--- a/flashinfer/fp4_quantization.py
+++ b/flashinfer/fp4_quantization.py
@@ -180,7 +180,8 @@ def fp4_quantize_sm100(
         - Scale factors tensor with shape determined by layout and sf_vec_size
     """
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     out_val = torch.empty(
         (*input.shape[:-1], input.shape[-1] // 2),
         dtype=torch.uint8,
@@ -480,9 +481,11 @@ def fp4_quantize(
     assert input.shape[-1] % sf_vec_size == 0

     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     # get input device sm version
-    major, minor = get_compute_capability(input.device)
+    # major, minor = get_compute_capability(input.device)
+    major, minor = get_compute_capability(input.place)
     x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
         input,
         global_scale,
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index f61bc55250..dbb9f69c88 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -20,7 +20,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi

 from ..artifacts import ArtifactPath, MetaInfoHash
 from ..autotuner import (
@@ -463,11 +466,15 @@ def __init__(
             use_mxfp8_act_scaling,
         )

+        def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
+            dtype_str = str(dtype).split(".", 1)[-1]
+            return tvm_ffi.dtype(dtype_str)
+
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[instance_key] = module.init(
-                x_dtype,
-                weight_dtype,
-                output_dtype,
+                paddle_dtype_to_tvm_ffi_dtype(x_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(output_dtype),
                 use_deepseek_fp8_block_scale,
                 use_w4_group_scaling,
                 use_mxfp8_act_scaling,
@@ -565,7 +572,8 @@ def cutlass_fused_moe(
     enable_pdl: Optional[bool] = None,
 ) -> List[torch.Tensor]:
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)

     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
@@ -623,17 +631,22 @@ def cutlass_fused_moe(
             else moe_runner.fused_moe_runner.run_moe
         )
         num_active_experts_per_node = torch.empty(
-            (1,), dtype=torch.int32, device=input.device
+            # (1,), dtype=torch.int32, device=input.device
+            (1,),
+            dtype=torch.int32,
+            device=input.place,
         )
         experts_to_token_score = torch.empty(
             (fc2_expert_weights.shape[0], input.shape[0]),
             dtype=torch.float32,
-            device=input.device,
+            # device=input.device,
+            device=input.place,
         )
         active_expert_global_ids = torch.empty(
             (fc2_expert_weights.shape[0],),
             dtype=torch.int32,
-            device=input.device,
+            # device=input.device,
+            device=input.place,
         )
         min_latency_output = (
             [
@@ -897,7 +910,8 @@ def cutlass_fused_moe(
         raise NotImplementedError("min latency mode not yet implemented for Blackwell.")

     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)

     num_rows = input.shape[0]
     if min_latency_mode:
@@ -906,10 +920,16 @@ def cutlass_fused_moe(
         output_shape = (num_rows, hidden_size)

     if output is None:
-        output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
     else:
         check_shape_dtype_device(
-            output, output_shape, output_dtype, input.device, "output"
+            # output, output_shape, output_dtype, input.device, "output"
+            output,
+            output_shape,
+            output_dtype,
+            input.place,
+            "output",
         )

     major, minor = torch.cuda.get_device_capability()
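For reference, the dtype bridging added to `flashinfer/fused_moe/core.py` above can be exercised on its own. The sketch below is a minimal, standalone version of `paddle_dtype_to_tvm_ffi_dtype`; it assumes Paddle dtypes stringify as `paddle.<name>` (e.g. `paddle.bfloat16`) and that `tvm_ffi.dtype` accepts the bare type name, exactly as the patched code does.

```python
import paddle

# Import tvm_ffi without Paddle's torch proxy active, mirroring the patched imports.
with paddle.compat.use_torch_proxy_guard(enable=False):
    import tvm_ffi


def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype) -> "tvm_ffi.dtype":
    # str(paddle.bfloat16) -> "paddle.bfloat16"; keep only the part after the prefix.
    dtype_str = str(dtype).split(".", 1)[-1]
    return tvm_ffi.dtype(dtype_str)


print(paddle_dtype_to_tvm_ffi_dtype(paddle.bfloat16))  # expected: bfloat16
```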
diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py
index 09fac79d19..fd0e6ac068 100644
--- a/flashinfer/jit/core.py
+++ b/flashinfer/jit/core.py
@@ -1,7 +1,10 @@
 import dataclasses
 import logging
 import os
-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Dict, List, Optional, Sequence, Union
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index 73470c58da..1c9d8eae08 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import List, Optional

-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 import torch

 from . import env as jit_env
diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index 06deaf55ea..1227597fdc 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -16,13 +16,12 @@
 import functools
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

 import torch
 import torch.version
-from torch.torch_version import TorchVersion
-from torch.torch_version import __version__ as torch_version

 from .jit import gen_jit_spec, env as jit_env
@@ -231,6 +230,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:

 @functools.cache
 def get_compute_capability(device: torch.device) -> Tuple[int, int]:
+    return torch.device.cuda.get_device_capability(device.gpu_device_id())
     if device.type != "cuda":
         raise ValueError("device must be a cuda device")
     return torch.cuda.get_device_capability(device.index)
@@ -249,7 +249,13 @@ def _check_cached_qkv_data_type(
         )


-if TorchVersion(torch_version) < TorchVersion("2.4"):
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
+if use_paddle_compatible_api() or torch.torch_version.TorchVersion(
+    torch.torch_version.__version__
+) < torch.torch_version.TorchVersion("2.4"):

     def register_custom_op(
         name: str,
@@ -492,7 +498,7 @@ def check_shape_dtype_device(
     expected_device: Optional[torch.device],
     name: str,
 ) -> None:
-    if expected_shape and x.shape != torch.Size(expected_shape):
+    if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
@@ -500,7 +506,8 @@ def check_shape_dtype_device(
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if expected_device and x.device != expected_device:
+    # if expected_device and x.device != expected_device:
+    if expected_device and x.place != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
@@ -548,8 +555,8 @@ def set_log_level(lvl_str: str) -> None:


 def device_support_pdl(device: torch.device) -> bool:
-    if device.type != "cuda":
-        return False
+    # if device.type != "cuda":
+    #     return False
     major, _ = get_compute_capability(device)
     return major >= 9
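The `PADDLE_COMPATIBLE_API` switch added to `flashinfer/utils.py` above is evaluated at import time (the `use_paddle_compatible_api() or ...` check sits at module level), so the variable has to be set before flashinfer is imported. A minimal usage sketch, assuming only what the diff itself introduces:

```python
import os

# Must be set before importing flashinfer: utils.py reads it at module import time
# and, when enabled, takes the same register_custom_op fallback branch that is
# otherwise reserved for torch < 2.4.
os.environ["PADDLE_COMPATIBLE_API"] = "1"

import flashinfer  # noqa: E402  -- imported after flipping the switch on purpose
```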
diff --git a/setup.py b/setup.py
index f3e1d991c5..8e0daafb3b 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,10 @@
 enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())


+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
 def write_if_different(path: Path, content: str) -> None:
     if path.exists() and path.read_text() == content:
         return
@@ -83,7 +87,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
 cmdclass: Mapping[str, type[setuptools.Command]] = {}
 install_requires = [
     "numpy",
-    "torch",
     "ninja",
     "requests",
     "nvidia-ml-py",
@@ -95,9 +98,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
     "packaging>=24.2",
     "nvidia-cudnn-frontend>=1.13.0",
 ]
+if not use_paddle_compatible_api():
+    install_requires.append("torch")
+
 generate_build_meta({})

 if enable_aot:
+    if use_paddle_compatible_api():
+        import paddle
+
+        paddle.compat.enable_torch_proxy()
+
     import torch

     cuda_version = get_cuda_version()
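Taken together, the setup.py changes make `torch` an optional dependency: when the switch is on, the AOT branch turns on Paddle's torch proxy before `import torch` runs. A standalone sketch of that sequence, assuming a Paddle build that ships `paddle.compat`; everything outside the two `paddle.compat` calls shown in the diff is illustrative:

```python
import os

# Same switch that setup.py's use_paddle_compatible_api() reads.
os.environ.setdefault("PADDLE_COMPATIBLE_API", "1")

if os.environ["PADDLE_COMPATIBLE_API"].lower() in ["1", "on", "true"]:
    import paddle

    # After this call, a later `import torch` is served by Paddle's
    # torch-compatibility proxy, which is why "torch" is dropped from
    # install_requires in this mode.
    paddle.compat.enable_torch_proxy()

import torch  # real PyTorch, or the Paddle proxy when the switch is enabled
```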