|
| 1 | +""" |
| 2 | +Liger-Kernel operators with automatic vendor-specific replacement. |
| 3 | +
|
| 4 | +This module provides two ways to import operators: |
| 5 | +
|
| 6 | +1. Import from this package (recommended for Function classes): |
| 7 | + from liger_kernel.ops import LigerGELUMulFunction |
| 8 | +
|
| 9 | + This automatically uses vendor-specific implementation if available. |
| 10 | +
|
| 11 | +2. Import from submodules (for kernel functions or specific access): |
| 12 | + from liger_kernel.ops.geglu import geglu_forward, geglu_backward |
| 13 | +
|
| 14 | + This always uses the default implementation (no auto-replacement). |
| 15 | +
|
| 16 | +The replacement mechanism: |
| 17 | +1. Default implementations are imported from individual modules (e.g., geglu.py) |
| 18 | +2. On module load, device is detected via infer_device() |
| 19 | +3. If running on a supported vendor device (npu, xpu, etc.), the default |
| 20 | + implementations are replaced with vendor-specific ones |
| 21 | +4. All subsequent imports from this package get the replaced versions |
| 22 | +
|
| 23 | +Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...) |
| 24 | + are NOT affected by the replacement mechanism. |
| 25 | +""" |
| 26 | + |
| 27 | +# ============================================================================= |
| 28 | +# Import default implementations |
| 29 | +# Both Function classes and kernel functions are imported here. |
| 30 | +# All of these can be replaced by vendor-specific implementations. |
| 31 | +# ============================================================================= |
| 32 | + |
| 33 | +from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction # noqa: F401 |
| 34 | +from liger_kernel.ops.cross_entropy import cross_entropy_backward # noqa: F401 |
| 35 | +from liger_kernel.ops.cross_entropy import cross_entropy_forward # noqa: F401 |
| 36 | +from liger_kernel.ops.dyt import LigerDyTFunction # noqa: F401 |
| 37 | +from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction # noqa: F401 |
| 38 | +from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction # noqa: F401 |
| 39 | +from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward # noqa: F401 |
| 40 | +from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward # noqa: F401 |
| 41 | +from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction # noqa: F401 |
| 42 | +from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward # noqa: F401 |
| 43 | +from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward # noqa: F401 |
| 44 | +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction # noqa: F401 |
| 45 | +from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward # noqa: F401 |
| 46 | +from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward # noqa: F401 |
| 47 | +from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction # noqa: F401 |
| 48 | +from liger_kernel.ops.geglu import LigerGELUMulFunction # noqa: F401 |
| 49 | +from liger_kernel.ops.geglu import geglu_backward # noqa: F401 |
| 50 | +from liger_kernel.ops.geglu import geglu_forward # noqa: F401 |
| 51 | +from liger_kernel.ops.group_norm import LigerGroupNormFunction # noqa: F401 |
| 52 | +from liger_kernel.ops.group_norm import group_norm_backward # noqa: F401 |
| 53 | +from liger_kernel.ops.group_norm import group_norm_forward # noqa: F401 |
| 54 | +from liger_kernel.ops.grpo_loss import GrpoLossFunction # noqa: F401 |
| 55 | +from liger_kernel.ops.jsd import LigerJSDFunction # noqa: F401 |
| 56 | +from liger_kernel.ops.jsd import jsd_backward # noqa: F401 |
| 57 | +from liger_kernel.ops.jsd import jsd_forward # noqa: F401 |
| 58 | +from liger_kernel.ops.kl_div import LigerKLDivLossFunction # noqa: F401 |
| 59 | +from liger_kernel.ops.layer_norm import LigerLayerNormFunction # noqa: F401 |
| 60 | +from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401 |
| 61 | +from liger_kernel.ops.layer_norm import layer_norm_forward # noqa: F401 |
| 62 | +from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401 |
| 63 | +from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401 |
| 64 | +from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401 |
| 65 | +from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401 |
| 66 | +from liger_kernel.ops.poly_norm import poly_norm_forward # noqa: F401 |
| 67 | +from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction # noqa: F401 |
| 68 | +from liger_kernel.ops.rms_norm import LigerRMSNormFunction # noqa: F401 |
| 69 | +from liger_kernel.ops.rms_norm import rms_norm_backward # noqa: F401 |
| 70 | +from liger_kernel.ops.rms_norm import rms_norm_forward # noqa: F401 |
| 71 | +from liger_kernel.ops.rope import LigerRopeFunction # noqa: F401 |
| 72 | +from liger_kernel.ops.rope import rope_backward # noqa: F401 |
| 73 | +from liger_kernel.ops.rope import rope_forward # noqa: F401 |
| 74 | +from liger_kernel.ops.softmax import LigerSoftmaxFunction # noqa: F401 |
| 75 | +from liger_kernel.ops.sparsemax import LigerSparsemaxFunction # noqa: F401 |
| 76 | +from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401 |
| 77 | +from liger_kernel.ops.swiglu import swiglu_backward # noqa: F401 |
| 78 | +from liger_kernel.ops.swiglu import swiglu_forward # noqa: F401 |
| 79 | +from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction # noqa: F401 |
| 80 | +from liger_kernel.ops.tiled_mlp import apply_tiled_mlp # noqa: F401 |
| 81 | +from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401 |
| 82 | + |
| 83 | +# NOTE: __all__ is intentionally NOT defined. |
| 84 | +# - Import from this package (liger_kernel.ops) -> subject to vendor replacement |
| 85 | +# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation |
| 86 | + |
| 87 | + |
| 88 | +# ============================================================================= |
| 89 | +# Vendor-specific replacement logic |
| 90 | +# ============================================================================= |
| 91 | + |
| 92 | + |
| 93 | +def _replace_with_vendor_ops(): |
| 94 | + """ |
| 95 | + Replace/add vendor-specific operator implementations. |
| 96 | +
|
| 97 | + This function is called automatically on module load. It: |
| 98 | + 1. Detects the current device (cuda, npu, xpu, etc.) |
| 99 | + 2. Looks up the vendor for that device via VENDOR_REGISTRY |
| 100 | + 3. Loads and applies vendor-specific implementations |
| 101 | +
|
| 102 | + Vendor implementations should be placed in: |
| 103 | + liger_kernel/ops/backends/_<vendor>/ops/ |
| 104 | +
|
| 105 | + If the vendor module defines __all__, only those symbols are exported. |
| 106 | + Otherwise, all public symbols (not starting with _) are auto-discovered. |
| 107 | +
|
| 108 | + Note: Vendor can both override existing ops AND add new vendor-specific ops. |
| 109 | + """ |
| 110 | + from liger_kernel.ops.backends import get_vendor_for_device |
| 111 | + from liger_kernel.utils import infer_device |
| 112 | + |
| 113 | + device = infer_device() |
| 114 | + |
| 115 | + # Look up vendor info for this device |
| 116 | + vendor_info = get_vendor_for_device(device) |
| 117 | + if vendor_info is None: |
| 118 | + return |
| 119 | + |
| 120 | + try: |
| 121 | + import importlib |
| 122 | + |
| 123 | + vendor_ops = importlib.import_module(vendor_info.module_path) |
| 124 | + |
| 125 | + # Get names to export: use __all__ if defined, otherwise auto-discover |
| 126 | + names_to_export = getattr(vendor_ops, "__all__", None) |
| 127 | + |
| 128 | + if names_to_export is None: |
| 129 | + # Auto-discover: find all public symbols (classes and functions) |
| 130 | + names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")] |
| 131 | + |
| 132 | + # Replace or add to this module's globals |
| 133 | + for name in names_to_export: |
| 134 | + globals()[name] = getattr(vendor_ops, name) |
| 135 | + |
| 136 | + except ImportError: |
| 137 | + # Vendor module not available, use default implementations |
| 138 | + pass |
| 139 | + |
| 140 | + |
| 141 | +_replace_with_vendor_ops() |
0 commit comments