# Debugging Pallas

<!--internal:0-->

<!--*
freshness: { owner: 'justinfu' reviewed: '2024-11-19' }
*-->

[TOC]

This document contains a collection of tips and tricks for debugging Pallas
programs. For any specific requests or ideas for improvement, please file a
ticket at https://github.com/jax-ml/jax/issues.

## Debugging Tools

### Interpret (HLO) Mode

Passing `interpret=True` to `pl.pallas_call` will run the kernel in HLO
instead of lowering it to Mosaic/Triton. This is useful for checking the
correctness of your program and for prototyping on smaller block sizes
(TPU kernels require block sizes of at least 8x128). HLO is also more
feature-complete, so a kernel will sometimes run in interpret mode but fail
otherwise; this helps confirm that the bug is in Pallas rather than in your
kernel.
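
For example, switching a kernel between compiled and interpret modes is a
one-line change. This is a minimal sketch; the `add_kernel` body and shapes
are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((8, 128), dtype=jnp.float32)
# interpret=True executes the kernel via HLO, so it also runs on CPU and
# permits prototyping with block shapes below the TPU minimum of 8x128.
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,
)(x, x)
```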

Note that interpret mode will not be able to fully replicate the behavior of
programs that use communication (DMAs) between devices. This is because the
low-level communication APIs are more general than the interface that XLA
provides via SPMD collective operations.

### debug_print

The `pl.debug_print` function can be used to print runtime values inside of a
kernel. The implementation is currently limited to scalar values, but we are
working on lifting this limitation.
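
For example, a minimal sketch of printing from inside a kernel (the kernel
body and format string are illustrative; `pl.debug_print` takes a format
string with `{}` placeholders):

```python
def kernel(x_ref, o_ref):
  # Only scalar values can currently be printed.
  pl.debug_print("x[0, 0] = {}", x_ref[0, 0])
  o_ref[...] = x_ref[...]
```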

For TPUs only, the kernel must be compiled with the
'xla_tpu_enable_log_recorder' option.
<!--internal:1-->

```python
kernel = pl.pallas_call(...)
compiled_kernel = (
    jax.jit(kernel)
    .lower(x)
    .compile({'xla_tpu_enable_log_recorder': 'true'})
)
result = compiled_kernel(x)
```

### Runtime Asserts

Checkify can be used to insert runtime asserts, NaN checks, out-of-bounds
errors, etc. inside of a kernel. Pallas implements two options for assertions:
a *hard assert*, which will crash the TPU if it fails, and a *functionalized
assert*, which simulates a runtime assertion that can be thrown as a Python
error after the kernel has successfully executed.

#### Hard assertion

Hard assertions can be inserted with `checkify.check` and enabled by running
your program with the `--jax_pallas_enable_runtime_assert` flag.

Your code will look like the following:

```python
from jax.experimental import checkify

def kernel(...):
  checkify.check(x > y, "Check x > y failed")  # Will halt if x <= y
```

If the assertion fails, this will print a relatively lengthy dump which
resembles the following:

```
E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3
```

The benefit of a hard assertion is that it is guaranteed to either pass or
halt the TPU. The kernel will never proceed past the assertion if it fails.
However, the downside is that if the assertion fails you will likely have to
restart the program in order to run any other TPU operations, and there is no
Python error thrown that can be caught.

#### Functionalized assertion

Functionalized asserts can be performed by checkify-ing the `pl.pallas_call`
op like so:

```python
from jax.experimental import checkify

def kernel(...):
  checkify.check(x > y, "Check x > y failed")  # Will throw an error if x <= y

kernel = pl.pallas_call(...)
checkified_kernel = checkify.checkify(kernel,
                                      errors=checkify.all_checks)
error, result = checkified_kernel(x)
error.throw()
```

This will throw a Python error if any checks failed, such as if a NaN occurred
or an out-of-bounds index was accessed.

The benefit of a functionalized assert is that it throws Python errors that
can be caught, and it will not interfere with downstream TPU operations.
However, it requires the kernel to complete successfully, meaning that if your
error would have caused a TPU crash, the crash would still happen and the
error would not be thrown.

### Dumping Jaxprs

Passing `debug=True` to `pl.pallas_call` will print out the Jaxpr of the
kernel as well as the lowered Mosaic code.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((8, 128), dtype=jnp.float32)
pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    debug=True,
    name="my_call",
)(x, x)
```

This will output:

```
The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000:
{ lambda ; a:MemRef<None>{float32[8,128]} b:MemRef<None>{float32[8,128]} c:MemRef<None>{float32[8,128]}. let
    d:f32[8,128] <- a[:,:]
    e:f32[8,128] <- b[:,:]
    f:f32[8,128] = add d e
    c[:,:] <- f
  in () }

The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000:
module {
  func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space<vmem>>, %arg1: memref<8x128xf32, #tpu.memory_space<vmem>>, %arg2: memref<8x128xf32, #tpu.memory_space<vmem>>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    %c0_1 = arith.constant 0 : index
    %c0_2 = arith.constant 0 : index
    %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    %2 = arith.addf %0, %1 : vector<8x128xf32>
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    return
  }
}
```

### Dumping Mosaic Passes

Mosaic is the underlying TPU compiler for Pallas. If you are running into
errors that originate from the Mosaic compiler, it can be useful to dump the
intermediate Mosaic code to see what is actually being generated.

Passing the `--xla_mosaic_dump_to=<directory>` argument will dump the output
of all intermediate Mosaic passes. The names of the dumped files contain
either the `name` parameter passed to the `pallas_call`, or the name of the
kernel function. A useful option is to dump to Sponge with
`--test_arg=--xla_mosaic_dump_to=sponge`, after which you will see all passes
under the “Artifacts” tab in Sponge.

### Static Verification

The static verification tool can be used to automatically detect race
conditions in distributed kernels. Because this tool uses formal verification,
it is best used for small kernels (<= 2 devices).

Verification can be performed by running your kernel with the
`--jax_pallas_dump_promela_to=<directory>` flag, which will output a Promela
dump file. Afterwards, the dump file can be analyzed using the
[`spin`](https://spinroot.com) tool. For example, with a dump named
`dump.pml`, run:

```
spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan
```

<!--internal:2-->

## Useful Command Line Flags

* OOB checks: `--xla_mosaic_on_device_checks=bounds`
* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true`
<!--internal:3-->
* Dump Mosaic: `--xla_mosaic_dump_to=<directory>`
* Enable trace markers in XProf: `--xla_enable_transpose_trace`

## Common Errors

### INTERNAL Mosaic failed to compile TPU Kernel

`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X`

This error means that you hit an unimplemented case in the underlying Mosaic
compiler. Our recommended course of action is to file a ticket if one does not
already exist for your specific error.

In some cases, your error may be due to an operation which cannot be
implemented efficiently in the compiler, in which case your best course of
action is to find a workaround. This is most commonly seen in `layout` and
`shape_cast` errors. The important tip to remember regarding layouts is that
the last 2 dimensions of arrays in Pallas are physically tiled into registers,
so any reshapes, slicing, transposes, etc. on the last 2 dimensions may
trigger a relayout.

### VerificationError

A verification error indicates that Pallas produced invalid code for Mosaic.

This is a bug in Pallas, so please file a bug at
https://github.com/jax-ml/jax/issues.

### LoweringError

This is a catch-all error type raised during Pallas-to-Mosaic lowering and can
have many causes. In most cases the error message should hint at what is
wrong.

For specific errors:

* `Mixed dtype operands in cmp` when using `jnp.mod`: use `lax.rem` instead of
  `jnp.mod`, as shown in the sketch below.
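
A minimal sketch of this workaround (the kernel and operand shapes are
illustrative):

```python
from jax import lax

def kernel(x_ref, y_ref, o_ref):
  # jnp.mod can lower to a comparison with mixed dtypes inside the kernel;
  # lax.rem lowers directly. Note that lax.rem takes the sign of the
  # dividend (C semantics), which differs from jnp.mod for negative operands.
  o_ref[...] = lax.rem(x_ref[...], y_ref[...])
```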