Changes from all commits
9 changes: 6 additions & 3 deletions flashinfer/fp4_quantization.py
@@ -180,7 +180,8 @@ def fp4_quantize_sm100(
- Scale factors tensor with shape determined by layout and sf_vec_size
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
out_val = torch.empty(
(*input.shape[:-1], input.shape[-1] // 2),
dtype=torch.uint8,
@@ -480,9 +481,11 @@ def fp4_quantize(

assert input.shape[-1] % sf_vec_size == 0
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
# get input device sm version
major, minor = get_compute_capability(input.device)
# major, minor = get_compute_capability(input.device)
major, minor = get_compute_capability(input.place)
x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
input,
global_scale,
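For context, the edits above swap torch's `input.device` attribute for Paddle's `input.place`, since Paddle tensors report their placement under a different name. A minimal sketch of how the two attributes relate, using a hypothetical `get_tensor_device` helper that is not part of this PR:

```python
def get_tensor_device(tensor):
    """Return the device/placement of a tensor from either framework.

    Paddle tensors expose their placement as `.place`, while torch tensors
    expose `.device`; this sketch simply checks which attribute is present.
    """
    if hasattr(tensor, "place"):
        return tensor.place  # e.g. a Paddle CUDAPlace for a GPU tensor
    return tensor.device     # e.g. device(type='cuda', index=0) for torch
```

Under the torch-proxy mode this PR targets, the tensors flowing through these kernels are Paddle tensors, which appears to be why the call sites now read `.place` directly instead of going through a helper like the one sketched here.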
42 changes: 31 additions & 11 deletions flashinfer/fused_moe/core.py
@@ -20,7 +20,10 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import tvm_ffi
import paddle

with paddle.compat.use_torch_proxy_guard(enable=False):
import tvm_ffi

from ..artifacts import ArtifactPath, MetaInfoHash
from ..autotuner import (
@@ -463,11 +466,15 @@ def __init__(
use_mxfp8_act_scaling,
)

def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
dtype_str = str(dtype).split(".", 1)[-1]
return tvm_ffi.dtype(dtype_str)

if instance_key not in MoERunner.runner_dict:
MoERunner.runner_dict[instance_key] = module.init(
x_dtype,
weight_dtype,
output_dtype,
paddle_dtype_to_tvm_ffi_dtype(x_dtype),
paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
paddle_dtype_to_tvm_ffi_dtype(output_dtype),
Review comment from the PR author on lines +475 to +477:
Once data-apis/array-api#972 has been implemented, we won't need these three lines anymore.

use_deepseek_fp8_block_scale,
use_w4_group_scaling,
use_mxfp8_act_scaling,
@@ -565,7 +572,8 @@ def cutlass_fused_moe(
enable_pdl: Optional[bool] = None,
) -> List[torch.Tensor]:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)
tuner = AutoTuner.get()
MoERunner.refine_tuning_config(tune_max_num_tokens)

@@ -623,17 +631,22 @@
else moe_runner.fused_moe_runner.run_moe
)
num_active_experts_per_node = torch.empty(
(1,), dtype=torch.int32, device=input.device
# (1,), dtype=torch.int32, device=input.device
(1,),
dtype=torch.int32,
device=input.place,
)
experts_to_token_score = torch.empty(
(fc2_expert_weights.shape[0], input.shape[0]),
dtype=torch.float32,
device=input.device,
# device=input.device,
device=input.place,
)
active_expert_global_ids = torch.empty(
(fc2_expert_weights.shape[0],),
dtype=torch.int32,
device=input.device,
# device=input.device,
device=input.place,
)
min_latency_output = (
[
@@ -897,7 +910,8 @@ def cutlass_fused_moe(
raise NotImplementedError("min latency mode not yet implemented for Blackwell.")

if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
# enable_pdl = device_support_pdl(input.device)
enable_pdl = device_support_pdl(input.place)

num_rows = input.shape[0]
if min_latency_mode:
@@ -906,10 +920,16 @@
output_shape = (num_rows, hidden_size)

if output is None:
output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
# output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
else:
check_shape_dtype_device(
output, output_shape, output_dtype, input.device, "output"
# output, output_shape, output_dtype, input.device, "output"
output,
output_shape,
output_dtype,
input.place,
"output",
)

major, minor = torch.cuda.get_device_capability()
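For reference, the `paddle_dtype_to_tvm_ffi_dtype` helper added above converts a Paddle dtype to a tvm_ffi dtype by stringifying it and dropping the `paddle.` prefix. A small illustration of the assumed behavior; the example mappings are based on how Paddle is expected to stringify its dtypes and are not verified output:

```python
import paddle

with paddle.compat.use_torch_proxy_guard(enable=False):
    import tvm_ffi


def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
    # str(paddle.float16) is assumed to render as "paddle.float16";
    # keep the part after the first dot and let tvm_ffi parse the bare name.
    dtype_str = str(dtype).split(".", 1)[-1]
    return tvm_ffi.dtype(dtype_str)


# Assumed mappings, for illustration only:
#   paddle.float16  -> tvm_ffi.dtype("float16")
#   paddle.bfloat16 -> tvm_ffi.dtype("bfloat16")
```

As the review comment above notes, once data-apis/array-api#972 is implemented this explicit conversion should become unnecessary.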
5 changes: 4 additions & 1 deletion flashinfer/jit/core.py
@@ -1,7 +1,10 @@
import dataclasses
import logging
import os
import tvm_ffi
import paddle

with paddle.compat.use_torch_proxy_guard(enable=False):
import tvm_ffi
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
5 changes: 4 additions & 1 deletion flashinfer/jit/cpp_ext.py
@@ -10,7 +10,10 @@
from pathlib import Path
from typing import List, Optional

import tvm_ffi
import paddle

with paddle.compat.use_torch_proxy_guard(enable=False):
import tvm_ffi
import torch

from . import env as jit_env
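The same import pattern appears in both flashinfer/jit/core.py and flashinfer/jit/cpp_ext.py: `tvm_ffi` is imported inside `paddle.compat.use_torch_proxy_guard(enable=False)`. A minimal sketch of the pattern, assuming the guard temporarily turns off Paddle's torch proxy for the duration of the `with` block so that `tvm_ffi` is imported without the proxy in effect:

```python
import paddle

# Import tvm_ffi with Paddle's torch proxy disabled, then restore the
# previous proxy state once the block exits (sketch of the PR's pattern).
with paddle.compat.use_torch_proxy_guard(enable=False):
    import tvm_ffi

import torch  # imported outside the guard, as in flashinfer/jit/cpp_ext.py
```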
21 changes: 14 additions & 7 deletions flashinfer/utils.py
@@ -16,13 +16,12 @@

import functools
import math
import os
from enum import Enum
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

import torch
import torch.version
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

from .jit import gen_jit_spec, env as jit_env

@@ -231,6 +230,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:

@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
return torch.device.cuda.get_device_capability(device.gpu_device_id())
if device.type != "cuda":
raise ValueError("device must be a cuda device")
return torch.cuda.get_device_capability(device.index)
@@ -249,7 +249,13 @@ def _check_cached_qkv_data_type(
)


if TorchVersion(torch_version) < TorchVersion("2.4"):
def use_paddle_compatible_api() -> bool:
return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]


if use_paddle_compatible_api() or torch.torch_version.TorchVersion(
torch.torch_version.__version__
) < torch.torch_version.TorchVersion("2.4"):

def register_custom_op(
name: str,
@@ -492,15 +498,16 @@ def check_shape_dtype_device(
expected_device: Optional[torch.device],
name: str,
) -> None:
if expected_shape and x.shape != torch.Size(expected_shape):
if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
raise ValueError(
f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
)
if expected_dtype and x.dtype != expected_dtype:
raise ValueError(
f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
)
if expected_device and x.device != expected_device:
# if expected_device and x.device != expected_device:
if expected_device and x.place != expected_device:
raise ValueError(
f"Invalid device of {name}: expected {expected_device}, got {x.device}"
)
@@ -548,8 +555,8 @@ def set_log_level(lvl_str: str) -> None:


def device_support_pdl(device: torch.device) -> bool:
if device.type != "cuda":
return False
# if device.type != "cuda":
# return False
major, _ = get_compute_capability(device)
return major >= 9

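The new `use_paddle_compatible_api()` switch in flashinfer/utils.py (mirrored in setup.py below) reads the `PADDLE_COMPATIBLE_API` environment variable and treats `1`, `on`, or `true` as enabled. When it is on, the module-level branch above falls back to the pre-torch-2.4 `register_custom_op` path. An illustrative usage sketch; the import-time effect described in the comments is my reading of that branch:

```python
import os

# Opt in to the Paddle-compatible code paths before flashinfer is imported,
# since use_paddle_compatible_api() is evaluated at module import time.
os.environ["PADDLE_COMPATIBLE_API"] = "1"  # "on" or "true" also count as enabled

import flashinfer  # noqa: E402  (import placed after the env var on purpose)
```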
13 changes: 12 additions & 1 deletion setup.py
@@ -31,6 +31,10 @@
enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())


def use_paddle_compatible_api() -> bool:
return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]


def write_if_different(path: Path, content: str) -> None:
if path.exists() and path.read_text() == content:
return
@@ -83,7 +87,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
cmdclass: Mapping[str, type[setuptools.Command]] = {}
install_requires = [
"numpy",
"torch",
"ninja",
"requests",
"nvidia-ml-py",
@@ -95,9 +98,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"packaging>=24.2",
"nvidia-cudnn-frontend>=1.13.0",
]
if not use_paddle_compatible_api():
install_requires.append("torch")

generate_build_meta({})

if enable_aot:
if use_paddle_compatible_api():
import paddle

paddle.compat.enable_torch_proxy()

import torch

cuda_version = get_cuda_version()
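Putting the setup.py changes together: with `PADDLE_COMPATIBLE_API` enabled, `torch` is no longer appended to `install_requires`, and the AOT branch calls `paddle.compat.enable_torch_proxy()` before importing torch, so the import is served by Paddle's proxy. An illustrative sketch of that order of operations, not an official build recipe:

```python
import os

os.environ["PADDLE_COMPATIBLE_API"] = "1"   # keep torch out of install_requires

import paddle

paddle.compat.enable_torch_proxy()  # as in setup.py's AOT branch above

import torch  # resolved through Paddle's torch proxy from this point on
```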