Skip to content

Commit fe1ea95

Browse files
pillumina and Tcc0403 authored
[refactor] decoupling ops implementations for different vendors (#973)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Related issue: #965 Have mocked following kernel modules for testing: - `/backends/_ascend/ops/rope.py` ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> ![1b9d022b343898aa15077b5dec557da4](https://github.com/user-attachments/assets/dee75885-19aa-48e3-9bab-f904a73cca65) test cases in `transformer/test_rope.py` could capture stdout message: "Using NPU LigerRopeFunction" means replacing rope kernel successfully. - Hardware Type: <> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Tcc0403 <[email protected]>
1 parent 49e6353 commit fe1ea95

31 files changed

+432
-46
lines changed

src/liger_kernel/ops/__init__.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""
2+
Liger-Kernel operators with automatic vendor-specific replacement.
3+
4+
This module provides two ways to import operators:
5+
6+
1. Import from this package (recommended for Function classes):
7+
from liger_kernel.ops import LigerGELUMulFunction
8+
9+
This automatically uses the vendor-specific implementation if one is available.
10+
11+
2. Import from submodules (for kernel functions or specific access):
12+
from liger_kernel.ops.geglu import geglu_forward, geglu_backward
13+
14+
This always uses the default implementation (no auto-replacement).
15+
16+
The replacement mechanism:
17+
1. Default implementations are imported from individual modules (e.g., geglu.py)
18+
2. On module load, device is detected via infer_device()
19+
3. If running on a supported vendor device (npu, xpu, etc.), the default
20+
implementations are replaced with vendor-specific ones
21+
4. All subsequent imports from this package get the replaced versions
22+
23+
Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
24+
are NOT affected by the replacement mechanism.
25+
"""
26+
27+
# =============================================================================
28+
# Import default implementations
29+
# Both Function classes and kernel functions are imported here.
30+
# All of these can be replaced by vendor-specific implementations.
31+
# =============================================================================
32+
33+
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction # noqa: F401
34+
from liger_kernel.ops.cross_entropy import cross_entropy_backward # noqa: F401
35+
from liger_kernel.ops.cross_entropy import cross_entropy_forward # noqa: F401
36+
from liger_kernel.ops.dyt import LigerDyTFunction # noqa: F401
37+
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction # noqa: F401
38+
from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction # noqa: F401
39+
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward # noqa: F401
40+
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward # noqa: F401
41+
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction # noqa: F401
42+
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward # noqa: F401
43+
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward # noqa: F401
44+
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction # noqa: F401
45+
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward # noqa: F401
46+
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward # noqa: F401
47+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction # noqa: F401
48+
from liger_kernel.ops.geglu import LigerGELUMulFunction # noqa: F401
49+
from liger_kernel.ops.geglu import geglu_backward # noqa: F401
50+
from liger_kernel.ops.geglu import geglu_forward # noqa: F401
51+
from liger_kernel.ops.group_norm import LigerGroupNormFunction # noqa: F401
52+
from liger_kernel.ops.group_norm import group_norm_backward # noqa: F401
53+
from liger_kernel.ops.group_norm import group_norm_forward # noqa: F401
54+
from liger_kernel.ops.grpo_loss import GrpoLossFunction # noqa: F401
55+
from liger_kernel.ops.jsd import LigerJSDFunction # noqa: F401
56+
from liger_kernel.ops.jsd import jsd_backward # noqa: F401
57+
from liger_kernel.ops.jsd import jsd_forward # noqa: F401
58+
from liger_kernel.ops.kl_div import LigerKLDivLossFunction # noqa: F401
59+
from liger_kernel.ops.layer_norm import LigerLayerNormFunction # noqa: F401
60+
from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401
61+
from liger_kernel.ops.layer_norm import layer_norm_forward # noqa: F401
62+
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401
63+
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401
64+
from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401
65+
from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401
66+
from liger_kernel.ops.poly_norm import poly_norm_forward # noqa: F401
67+
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction # noqa: F401
68+
from liger_kernel.ops.rms_norm import LigerRMSNormFunction # noqa: F401
69+
from liger_kernel.ops.rms_norm import rms_norm_backward # noqa: F401
70+
from liger_kernel.ops.rms_norm import rms_norm_forward # noqa: F401
71+
from liger_kernel.ops.rope import LigerRopeFunction # noqa: F401
72+
from liger_kernel.ops.rope import rope_backward # noqa: F401
73+
from liger_kernel.ops.rope import rope_forward # noqa: F401
74+
from liger_kernel.ops.softmax import LigerSoftmaxFunction # noqa: F401
75+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction # noqa: F401
76+
from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401
77+
from liger_kernel.ops.swiglu import swiglu_backward # noqa: F401
78+
from liger_kernel.ops.swiglu import swiglu_forward # noqa: F401
79+
from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction # noqa: F401
80+
from liger_kernel.ops.tiled_mlp import apply_tiled_mlp # noqa: F401
81+
from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401
82+
83+
# NOTE: __all__ is intentionally NOT defined.
84+
# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
85+
# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
86+
87+
88+
# =============================================================================
89+
# Vendor-specific replacement logic
90+
# =============================================================================
91+
92+
93+
def _replace_with_vendor_ops():
    """
    Replace/add vendor-specific operator implementations.

    This function is called automatically on module load. It:
    1. Detects the current device (cuda, npu, xpu, etc.) via infer_device()
    2. Looks up the vendor for that device via get_vendor_for_device()
    3. Imports the vendor's ops package (vendor_info.module_path) and copies
       its exported symbols into this module's globals

    Vendor implementations should be placed in:
        liger_kernel/ops/backends/_<vendor>/ops/

    If the vendor module defines __all__, exactly those symbols are exported.
    Otherwise, all public symbols (not starting with _) are auto-discovered,
    with the exception of module objects: submodules of the vendor package and
    modules it merely imported are skipped, so that an auto-discovered vendor
    submodule (e.g. ``geglu``) never shadows the default submodule of the same
    name on this package.

    Note: Vendor can both override existing ops AND add new vendor-specific ops.

    Returns:
        None. The function works purely by mutating this module's globals. It
        returns early, leaving the default implementations in place, when no
        vendor is registered for the device or the vendor ops package cannot
        be imported.
    """
    import importlib
    import logging
    import types

    from liger_kernel.ops.backends import get_vendor_for_device
    from liger_kernel.utils import infer_device

    # Look up vendor info for this device
    vendor_info = get_vendor_for_device(infer_device())
    if vendor_info is None:
        return

    # Keep the try body minimal: only the import itself is allowed to fail.
    try:
        vendor_ops = importlib.import_module(vendor_info.module_path)
    except ImportError as exc:
        # A missing ops package is the expected "nothing to override" case and
        # stays silent. Any other ImportError means the package exists but is
        # broken (e.g. a missing dependency); still fall back to the defaults,
        # but say so instead of hiding the failure.
        if getattr(exc, "name", None) != vendor_info.module_path:
            logging.getLogger(__name__).warning(
                "Failed to import vendor ops %r (%s); using default implementations.",
                vendor_info.module_path,
                exc,
            )
        return

    # Get names to export: use __all__ if defined, otherwise auto-discover
    names_to_export = getattr(vendor_ops, "__all__", None)
    auto_discovered = names_to_export is None
    if auto_discovered:
        names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]

    # Replace or add to this module's globals
    namespace = globals()
    for name in names_to_export:
        value = getattr(vendor_ops, name)
        # An explicit __all__ is honored verbatim; only auto-discovery filters
        # out module objects, which dir() reports alongside real operators.
        if auto_discovered and isinstance(value, types.ModuleType):
            continue
        namespace[name] = value
139+
140+
141+
# Import-time hook: apply vendor overrides when this module is loaded.
_replace_with_vendor_ops()
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Adding a New Vendor Backend
2+
3+
This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
4+
5+
## Concepts
6+
7+
- **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
8+
- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
9+
- **VendorInfo**: Defines the mapping between vendor and device
10+
11+
## Directory Structure
12+
13+
```
14+
backends/
15+
├── README.md
16+
├── __init__.py
17+
├── registry.py # VendorInfo, register_vendor(), VENDOR_REGISTRY
18+
├── _ascend/ # Ascend (Huawei) vendor - supports NPU
19+
│ ├── __init__.py # Registers VendorInfo for NPU
20+
│ └── ops/
21+
│ ├── __init__.py # Exports vendor-specific implementations
22+
│ └── geglu.py # NPU-specific GEGLU implementation
23+
└── _<vendor>/ # Your new vendor backend
24+
└── ...
25+
```
26+
27+
## How It Works
28+
29+
1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
30+
2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
31+
3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
32+
4. It detects the current device via `infer_device()` and looks up the vendor
33+
5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
34+
35+
## Adding a New Vendor
36+
37+
### Step 1: Create Directory Structure
38+
39+
```bash
40+
mkdir -p backends/_<vendor>/ops
41+
touch backends/_<vendor>/__init__.py
42+
touch backends/_<vendor>/ops/__init__.py
43+
```
44+
45+
### Step 2: Register Your Vendor
46+
47+
In `backends/_<vendor>/__init__.py`, register your vendor:
48+
49+
```python
50+
"""
51+
<Vendor> backend for Liger-Kernel.
52+
"""
53+
54+
from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
55+
56+
register_vendor(
57+
VendorInfo(
58+
vendor="<vendor>",
59+
device="<device>",
60+
)
61+
)
62+
```
63+
64+
65+
### Step 3: Ensure Device Detection Works
66+
67+
Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
68+
69+
```python
70+
def infer_device():
71+
if torch.cuda.is_available():
72+
return "cuda"
73+
if is_npu_available():
74+
return "npu"
75+
# Add your device detection here
76+
if is_<device>_available():
77+
return "<device>"
78+
return "cpu"
79+
```
80+
81+
### Step 4: Implement Vendor-Specific Operators
82+
83+
Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
84+
85+
```python
86+
import torch
87+
88+
class LigerGELUMulFunction(torch.autograd.Function):
89+
"""
90+
Vendor-specific LigerGELUMulFunction implementation.
91+
"""
92+
@staticmethod
93+
def forward(ctx, a, b):
94+
# Your vendor-specific forward implementation
95+
...
96+
97+
@staticmethod
98+
def backward(ctx, dc):
99+
# Your vendor-specific backward implementation
100+
...
101+
102+
# Optional: vendor-specific kernel functions
103+
def geglu_forward_vendor(a, b):
104+
...
105+
106+
def geglu_backward_vendor(a, b, dc):
107+
...
108+
```
109+
110+
### Step 5: Export in `ops/__init__.py`
111+
112+
In `backends/_<vendor>/ops/__init__.py`, export your implementations:
113+
114+
```python
115+
"""
116+
<Vendor>-specific operator implementations.
117+
"""
118+
119+
from .<module> import (
120+
LigerGELUMulFunction,
121+
geglu_forward_vendor as geglu_forward, # Rename to match default API
122+
geglu_backward_vendor as geglu_backward,
123+
)
124+
125+
# Explicitly declare what to export (recommended)
126+
__all__ = [
127+
"LigerGELUMulFunction",
128+
"geglu_forward",
129+
"geglu_backward",
130+
]
131+
```
132+
133+
## Key Points
134+
135+
### Incremental Override
136+
137+
You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
138+
139+
### Vendor-Specific Additions
140+
141+
Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to the `liger_kernel.ops` namespace for users to import.
142+
143+
### Naming Convention
144+
145+
- Use the **same class/function names** as the default implementations for overrides
146+
- This allows seamless replacement without changing user code
147+
- Use `as` imports to rename if your internal naming differs
148+
149+
## Example: Ascend NPU Backend
150+
151+
See the `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import importlib
2+
import pkgutil
3+
4+
from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401
5+
from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401
6+
from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401
7+
from liger_kernel.ops.backends.registry import register_vendor # noqa: F401
8+
9+
# Trigger vendor registration by importing every _<vendor> subpackage found
# next to this file; each one calls register_vendor() from its __init__.py.
for _, modname, ispkg in pkgutil.iter_modules(__path__):
    # Skip plain modules (e.g. registry.py) and anything not named like a vendor package.
    if not ispkg or not modname.startswith("_"):
        continue
    importlib.import_module(".".join((__name__, modname)))
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from liger_kernel.ops.backends.registry import VendorInfo
2+
from liger_kernel.ops.backends.registry import register_vendor
3+
4+
# Register Ascend vendor for NPU device
5+
# Import-time side effect: stores this VendorInfo in VENDOR_REGISTRY under "npu",
# so get_vendor_for_device("npu") resolves to the Ascend backend.
register_vendor(VendorInfo(vendor="ascend", device="npu"))
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Ascend NPU operator implementations.
3+
4+
This module exports Ascend NPU-optimized implementations that will automatically
5+
replace the default implementations when running on NPU devices.
6+
7+
Both Function classes and kernel functions can be exported here.
8+
9+
To add a new operator:
10+
1. Create the implementation file (e.g., rms_norm.py)
11+
2. Import the Function class and/or kernel functions here
12+
3. Optionally add to __all__ for explicit control
13+
14+
If __all__ is not defined, all public symbols will be auto-discovered.
15+
"""
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Vendor registry for Liger-Kernel multi-backend support.
3+
4+
This module defines VendorInfo and the registry for vendor registration.
5+
Each vendor registers itself by calling register_vendor() in its __init__.py.
6+
"""
7+
8+
from dataclasses import dataclass
9+
from typing import Optional
10+
11+
# Dotted name of the package that contains this module, derived from __name__
# so the path is never hardcoded (e.g. "liger_kernel.ops.backends").
_BACKENDS_PACKAGE = __name__.rsplit(".", maxsplit=1)[0]


@dataclass
class VendorInfo:
    """
    Pairing of a chip vendor with the device type it serves.

    Attributes:
        vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
        device: Device type this vendor supports (e.g., "npu", "xpu")
    """

    vendor: str
    device: str

    @property
    def module_path(self) -> str:
        """Dotted import path of the vendor's ops package: <backends package>._<vendor>.ops"""
        segments = (_BACKENDS_PACKAGE, f"_{self.vendor}", "ops")
        return ".".join(segments)
32+
33+
34+
# Global lookup table: device type -> VendorInfo.
# It starts empty and is filled in by register_vendor().
VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def register_vendor(vendor_info: VendorInfo) -> None:
    """
    Record a vendor in the global registry, keyed by its device type.

    Intended to be called from each vendor's __init__.py so the vendor
    registers itself on import. An existing entry for the same device is
    overwritten.

    Args:
        vendor_info: VendorInfo instance to register
    """
    device_key = vendor_info.device
    VENDOR_REGISTRY[device_key] = vendor_info
49+
50+
51+
def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    """
    Look up the vendor registered for a device type.

    Args:
        device: Device type (e.g., "npu", "xpu")

    Returns:
        The registered VendorInfo, or None when no vendor has registered
        for this device.
    """
    if device in VENDOR_REGISTRY:
        return VENDOR_REGISTRY[device]
    return None

0 commit comments

Comments
 (0)