
Commit f20b25f

[XPU] Fixed the issue of `num_warps` being passed in multiple times. (#831)
## Summary

When testing on the XPU platform, **num_warps** is passed in a second time via **kernel_args**, so running `pytest -rA test/transformers/test_layer_norm.py` fails with:

```
TypeError: triton.runtime.jit.KernelInterface.__getitem__.<locals>.<lambda>() got multiple values for keyword argument 'num_warps'
```

## Testing Done

Run `pytest -rA test/transformers/test_layer_norm.py`:

**Before modification:**

```
    def layer_norm_backward(dY, X, W, B, Mean, RSTD):
        """
        Args:
            dY: Gradient of output
            X: Input tensor
            W: Weight tensor
            B: Bias tensor
            Mean: Pre-computed mean
            RSTD: Pre-computed reciprocal standard deviation

        Returns:
            Tuple of (input_grad, weight_grad, bias_grad)
        """
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        n_rows, n_cols = dY.shape

        # Allocate gradient tensors
        DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)

        # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
        grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
        DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
        DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)

        # Calculate optimal block size and warp configuration
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        if n_cols > BLOCK_SIZE:
            raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")

        # Determine dtype for triton operations
        triton_dtype = (
            tl.float32
            if X.dtype == torch.float32
            else tl.bfloat16
            if X.dtype == torch.bfloat16
            else tl.float16
            if X.dtype == torch.float16
            else tl.float32  # fallback
        )

        # Use float32 for atomic operations if bfloat16 is not supported
        atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype

        # XPU-specific optimization
        kernel_args = {}
        if X.device.type == "xpu":
            kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

        # Launch kernel with one thread block per row for optimal performance
        grid = (n_rows,)
>       _layer_norm_backward_kernel[grid](
            X,
            W,
            Mean,
            RSTD,
            DX,
            DW,
            DB,
            dY,
            X.stride(0),
            DX.stride(0),
            dY.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            dtype=triton_dtype,
            atomic_dtype=atomic_dtype,
            num_warps=num_warps,
            **kernel_args,
        )
E       TypeError: triton.runtime.jit.KernelInterface.__getitem__.<locals>.<lambda>() got multiple values for keyword argument 'num_warps'

../../src/liger_kernel/ops/layer_norm.py:266: TypeError
```

**After modification:** The test cases passed.

**Device Name**: Intel(R) Data Center GPU Max 1550

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

When I run `pytest -rA test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gemma3_text-32-1e-05-dtype17-0.01-0.01-0.1-0.01-0.01-0.01]` locally, it sometimes passes and sometimes fails. This is **not related to this PR**; it may be a separate issue to be solved.

Co-authored-by: Shao Tang <[email protected]>
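The root cause is ordinary Python keyword-argument handling rather than anything Triton-specific. The sketch below (with a hypothetical `launch` function standing in for the kernel wrapper returned by `_layer_norm_backward_kernel[grid]`) reproduces the same error and shows why seeding `kernel_args` with the default `num_warps` avoids it:

```python
# Minimal reproduction of the failure mode; not Triton-specific.
# `launch` is a hypothetical stand-in for the kernel wrapper; any Python
# callable behaves the same way.
def launch(*args, **kwargs):
    return kwargs.get("num_warps")

kernel_args = {"grf_mode": "large", "num_warps": 32, "num_stages": 4}

# Passing num_warps explicitly AND inside **kernel_args duplicates the keyword.
try:
    launch(1, 2, num_warps=8, **kernel_args)
except TypeError as err:
    print(err)  # launch() got multiple values for keyword argument 'num_warps'

# The fix: seed the dict with the default first, let the XPU branch override it,
# and pass the keyword to the call site exactly once via **kernel_args.
kernel_args = {"num_warps": 8}
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
print(launch(1, 2, **kernel_args))  # 32
```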
1 parent 1ea9175 commit f20b25f

1 file changed: +1 −2 lines changed

src/liger_kernel/ops/layer_norm.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -256,8 +256,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     # Use float32 for atomic operations if bfloat16 is not supported
     atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
 
+    kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
-    kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
@@ -279,7 +279,6 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
         atomic_dtype=atomic_dtype,
-        num_warps=num_warps,
         **kernel_args,
     )
 
```
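For reference, after applying the hunks above, the relevant part of `layer_norm_backward` reads roughly as follows (reconstructed from the diff context; the rest of the function is unchanged and omitted):

```python
# Reconstructed from the diff above; surrounding setup elided.
# Seed kernel_args with the num_warps chosen by calculate_settings so the
# kernel launch receives the keyword exactly once, via **kernel_args.
kernel_args = {"num_warps": num_warps}
# XPU-specific optimization: override num_warps and add XPU-only launch options.
if X.device.type == "xpu":
    kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

_layer_norm_backward_kernel[grid](
    X, W, Mean, RSTD, DX, DW, DB, dY,
    X.stride(0), DX.stride(0), dY.stride(0), n_cols,
    BLOCK_SIZE=BLOCK_SIZE,
    dtype=triton_dtype,
    atomic_dtype=atomic_dtype,
    **kernel_args,
)
```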