Skip to content

Commit 71ed8ac

Browse files
authored
XPU: Enable new grf_mode settings (#1016)
## Summary The `grf_mode` API in `triton-xpu` changed with [this change](intel/intel-xpu-backend-for-triton#5430). The new API takes effect in `triton-xpu>=3.6`, so this PR selects the `grf_mode` value based on the installed Triton version. ## Testing Done I ran Liger-Kernel with both the new and the old Triton versions on PVC (GPU Max 1100), including the tests and the relevant benchmarks. - Hardware Type: XPU, Intel PVC (GPU Max 1100) - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence I will run all tests and update the PR description.
1 parent 2ce9bdd commit 71ed8ac

File tree

5 files changed

+24
-8
lines changed

5 files changed

+24
-8
lines changed

src/liger_kernel/ops/fused_add_rms_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from liger_kernel.ops.utils import calculate_settings
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import ensure_contiguous
11+
from liger_kernel.ops.utils import set_large_grf_mode
1112
from liger_kernel.ops.utils import torch_to_triton_dtype
1213
from liger_kernel.utils import get_npu_multi_processor_count
1314
from liger_kernel.utils import is_npu_available
@@ -247,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
247248
# XPU-specific optimization
248249
kernel_args = {}
249250
if X.device.type == "xpu":
250-
kernel_args["grf_mode"] = "large"
251+
set_large_grf_mode(kernel_args)
251252

252253
# TODO: add _block_fused_add_rms_norm_forward_kernel
253254
_fused_add_rms_norm_forward_kernel[(n_rows,)](
@@ -307,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
307308
# XPU-specific optimization
308309
kernel_args = {}
309310
if S.device.type == "xpu":
310-
kernel_args["grf_mode"] = "large"
311+
set_large_grf_mode(kernel_args)
311312

312313
# TODO: add _block_fused_add_rms_norm_backward_kernel
313314
_fused_add_rms_norm_backward_kernel[grid](

src/liger_kernel/ops/layer_norm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from liger_kernel.ops.utils import calculate_settings
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import ensure_contiguous
11+
from liger_kernel.ops.utils import set_large_grf_mode
1112
from liger_kernel.utils import get_npu_multi_processor_count
1213
from liger_kernel.utils import is_npu_available
1314

@@ -199,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
199200
# XPU-specific optimization
200201
kernel_args = {}
201202
if X.device.type == "xpu":
202-
kernel_args["grf_mode"] = "large"
203+
set_large_grf_mode(kernel_args)
203204

204205
# Launch kernel with one thread block per row for optimal performance
205206
grid = (n_rows,)
@@ -269,7 +270,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
269270
kernel_args = {"num_warps": num_warps}
270271
# XPU-specific optimization
271272
if X.device.type == "xpu":
272-
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
273+
kernel_args.update({"num_warps": 32, "num_stages": 4})
274+
set_large_grf_mode(kernel_args)
273275

274276
# Launch kernel with one thread block per row for optimal performance
275277
_layer_norm_backward_kernel[grid](

src/liger_kernel/ops/poly_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from liger_kernel.ops.utils import calculate_settings
88
from liger_kernel.ops.utils import compare_version
99
from liger_kernel.ops.utils import ensure_contiguous
10+
from liger_kernel.ops.utils import set_large_grf_mode
1011
from liger_kernel.utils import get_npu_multi_processor_count
1112
from liger_kernel.utils import is_npu_available
1213

@@ -239,7 +240,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
239240
# XPU-specific optimization
240241
kernel_args = {}
241242
if X.device.type == "xpu":
242-
kernel_args["grf_mode"] = "large"
243+
set_large_grf_mode(kernel_args)
243244

244245
# Launch kernel
245246
_poly_norm_forward_kernel[(n_rows,)](
@@ -310,7 +311,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
310311
# XPU-specific optimization
311312
kernel_args = {}
312313
if X.device.type == "xpu":
313-
kernel_args["grf_mode"] = "large"
314+
set_large_grf_mode(kernel_args)
314315

315316
# Launch backward kernel
316317
_poly_norm_backward_kernel[grid](

src/liger_kernel/ops/rms_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from liger_kernel.ops.utils import calculate_settings
2121
from liger_kernel.ops.utils import compare_version
2222
from liger_kernel.ops.utils import ensure_contiguous
23+
from liger_kernel.ops.utils import set_large_grf_mode
2324
from liger_kernel.ops.utils import torch_to_triton_dtype
2425
from liger_kernel.utils import get_npu_multi_processor_count
2526
from liger_kernel.utils import is_npu_available
@@ -436,7 +437,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
436437
# XPU-specific optimization
437438
kernel_args = {}
438439
if X.device.type == "xpu":
439-
kernel_args["grf_mode"] = "large"
440+
set_large_grf_mode(kernel_args)
440441
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
441442
_rms_norm_forward_kernel[(n_rows,)](
442443
Y,
@@ -516,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
516517
# XPU-specific optimization
517518
kernel_args = {}
518519
if X.device.type == "xpu":
519-
kernel_args["grf_mode"] = "large"
520+
set_large_grf_mode(kernel_args)
520521

521522
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
522523
_rms_norm_backward_kernel[grid](

src/liger_kernel/ops/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,14 @@ def get_npu_core_count(default: int = 20) -> int:
139139
return int(props.get("num_vectorcore", default))
140140
except Exception:
141141
return default
142+
143+
144+
def set_large_grf_mode(kernel_args: dict):
    """Select the large GRF mode for a kernel launch on XPU devices.

    The ``"grf_mode"`` entry of ``kernel_args`` is written in place and
    nothing is returned. The value is ``"256"`` when ``compare_version``
    reports that either Triton package is ``>= 3.6.0``; otherwise the
    older spelling ``"large"`` is used. The accepted value was changed in
    https://github.com/intel/intel-xpu-backend-for-triton/pull/5430.

    Args:
        kernel_args: Keyword arguments that the caller passes to the kernel
            launch; mutated by this function.
    """
    # On XPU, the Triton that is installed along with pytorch-xpu is named
    # `pytorch-triton-xpu`; an XPU Triton built from source is named `triton`.
    triton_packages = ("pytorch-triton-xpu", "triton")
    has_new_api = any(compare_version(package, operator.ge, "3.6.0") for package in triton_packages)
    kernel_args["grf_mode"] = "256" if has_new_api else "large"

0 commit comments

Comments
 (0)