Commit f20b25f
[XPU] Fixed the issue of the num_warps parameter being passed in multiple times. (#831)
## Summary
When testing on the XPU platform, **num_warps** is passed a second time through
**kernel_args**, so running ```pytest -rA test/transformers/test_layer_norm.py``` fails with
```TypeError: triton.runtime.jit.KernelInterface.__getitem__.<locals>.<lambda>() got
multiple values for keyword argument 'num_warps'```.
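The conflict (and the shape of the fix) can be reproduced with plain Python keyword-argument handling. The sketch below is illustrative only and is not the exact diff in this PR: `launch_kernel` is a hypothetical stand-in for the Triton launch `_layer_norm_backward_kernel[grid](...)`, and `device_type` stands in for `X.device.type`. Passing `num_warps` both as an explicit keyword and inside `**kernel_args` raises the same `TypeError`; carrying it only inside `kernel_args` delivers the keyword exactly once.
```python
# Illustrative sketch only -- `launch_kernel` is a hypothetical stand-in for the
# Triton launch `_layer_norm_backward_kernel[grid](...)`; it is not liger_kernel code.
def launch_kernel(BLOCK_SIZE, num_warps=4, **extra):
    """Echo the launch options that would be forwarded to the kernel."""
    return {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, **extra}

BLOCK_SIZE, num_warps = 1024, 8  # e.g. as returned by calculate_settings(n_cols)

# Buggy pattern: num_warps arrives twice (explicit kwarg + inside kernel_args).
kernel_args = {"grf_mode": "large", "num_warps": 32, "num_stages": 4}
try:
    launch_kernel(BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, **kernel_args)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'num_warps'

# Fixed pattern: keep num_warps only inside kernel_args; the XPU branch may
# override it, and the launch receives the keyword exactly once.
kernel_args = {"num_warps": num_warps}
device_type = "xpu"  # hypothetical stand-in for X.device.type
if device_type == "xpu":
    kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
print(launch_kernel(BLOCK_SIZE=BLOCK_SIZE, **kernel_args))
```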
## Testing Done
Run ```pytest -rA test/transformers/test_layer_norm.py```:
**Before modification:**
```
    def layer_norm_backward(dY, X, W, B, Mean, RSTD):
        """
        Args:
            dY: Gradient of output
            X: Input tensor
            W: Weight tensor
            B: Bias tensor
            Mean: Pre-computed mean
            RSTD: Pre-computed reciprocal standard deviation
        Returns:
            Tuple of (input_grad, weight_grad, bias_grad)
        """
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        n_rows, n_cols = dY.shape
        # Allocate gradient tensors
        DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
        # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
        grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
        DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
        DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
        # Calculate optimal block size and warp configuration
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        if n_cols > BLOCK_SIZE:
            raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
        # Determine dtype for triton operations
        triton_dtype = (
            tl.float32
            if X.dtype == torch.float32
            else tl.bfloat16
            if X.dtype == torch.bfloat16
            else tl.float16
            if X.dtype == torch.float16
            else tl.float32  # fallback
        )
        # Use float32 for atomic operations if bfloat16 is not supported
        atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
        # XPU-specific optimization
        kernel_args = {}
        if X.device.type == "xpu":
            kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
        # Launch kernel with one thread block per row for optimal performance
        grid = (n_rows,)
>       _layer_norm_backward_kernel[grid](
            X,
            W,
            Mean,
            RSTD,
            DX,
            DW,
            DB,
            dY,
            X.stride(0),
            DX.stride(0),
            dY.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            dtype=triton_dtype,
            atomic_dtype=atomic_dtype,
            num_warps=num_warps,
            **kernel_args,
        )
E       TypeError: triton.runtime.jit.KernelInterface.__getitem__.<locals>.<lambda>() got multiple values for keyword argument 'num_warps'

../../src/liger_kernel/ops/layer_norm.py:266: TypeError
```
**After modification:**
All test cases passed.
**Device Name**: Intel(R) Data Center GPU Max 1550
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
When I run ```pytest -rA
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gemma3_text-32-1e-05-dtype17-0.01-0.01-0.1-0.01-0.01-0.01]```
locally, it sometimes passes and sometimes fails. This flakiness is **not related
to this PR** and may be a separate issue worth investigating.
Co-authored-by: Shao Tang <[email protected]>
1 file changed, +1 −2 lines changed