# Debugging Pallas

<!--internal:0-->

<!--*
freshness: { owner: 'justinfu' reviewed: '2024-11-19' }
*-->

[TOC]

This document contains a collection of tips and tricks for debugging Pallas
programs. For any specific requests or ideas for improvement, please file
a ticket at https://github.com/jax-ml/jax/issues.

## Debugging Tools

### Interpret (HLO) Mode

Passing `interpret=True` to `pl.pallas_call` will run the kernel via HLO instead of lowering it to Mosaic/Triton. This is useful for checking the correctness of your program and for prototyping with smaller block sizes (TPU kernels require block sizes of at least 8x128). HLO is also more feature-complete, so a kernel that runs in interpret mode but fails otherwise indicates that the bug is in Pallas rather than in your kernel.

Note that interpret mode cannot fully replicate the behavior of programs that use communication (DMAs) between devices. This is because the low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations.
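As a minimal sketch (the `add_kernel` and shapes here are made up for this example), interpret mode can be enabled like so:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

# Small, non-tiled shapes are fine here because interpret mode
# runs via HLO on any backend instead of compiling for the TPU.
x = jnp.arange(4.0)
result = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x, x)
```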

### debug_print

The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation.
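A small sketch of its use (the kernel and shapes here are illustrative; note the scalar-only limitation above):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Only scalar values can currently be printed.
  pl.debug_print("first element = {}", x_ref[0])
  o_ref[...] = x_ref[...] * 2

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # debug_print also works in interpret mode
)(x)
```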

For TPUs only, the kernel must be compiled with the `xla_tpu_enable_log_recorder` option.
<!--internal:1-->

```python
kernel = pl.pallas_call(...)
compiled_kernel = (
    jax.jit(kernel)
    .lower(x)
    .compile({'xla_tpu_enable_log_recorder': 'true'})
)
result = compiled_kernel(x)
```

### Runtime Asserts

Checkify can be used to insert runtime asserts, NaN checks, out-of-bounds errors, etc. inside of a kernel.
Pallas implements two options for assertions: a *hard assert*, which will crash the TPU if it fails, and a *functionalized assertion*, which simulates a runtime assertion that can be thrown
as a Python error after the kernel has successfully executed.

#### Hard assertion

Hard assertions can be inserted by calling `checkify.check`
and running your program with the `--jax_pallas_enable_runtime_assert` flag.

Your code will look like the following:

```python
from jax.experimental import checkify

def kernel(...):
  checkify.check(x > y, "Check x > y failed")  # Will halt if x <= y
```

If the check fails, this will print a relatively lengthy dump resembling the following:

```
E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3
```
| 65 | + |
| 66 | +The benefit of a hard assertion is that it is guaranteed to either pass or |
| 67 | +halt the TPU. The kernel will never proceed past the assertion if it fails. |
| 68 | +However, the downside is that if the assertion fails you will |
| 69 | +likely have to restart the program in order to run any other TPU operations, |
| 70 | +and there is no Python error thrown that can be caught. |
| 71 | + |
#### Functionalized assertion

Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op like so:

```python
from jax.experimental import checkify

def kernel(...):
  checkify.check(x > y, "Check x > y failed")  # Will throw an error if x <= y

kernel = pl.pallas_call(...)
checkified_kernel = checkify.checkify(kernel,
                                      errors=checkify.all_checks)
error, result = checkified_kernel(x)
error.throw()
```

This will throw a Python error if any checks failed, such as if a NaN occurred
or if an out-of-bounds index was accessed.

The benefit of a functionalized assert is that it will throw Python errors
that can be caught, and it will not interfere with downstream TPU operations.
However, it requires the kernel to successfully complete, meaning if your
error would have caused a TPU crash, the crash would still happen and
the error would not be thrown.

### Dumping Jaxprs

Passing `debug=True` to `pl.pallas_call` will print out the Jaxpr of the kernel as well as the lowered Mosaic code.

```python
def kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((8, 128), dtype=jnp.float32)
pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    debug=True,
    name="my_call",
)(x, x)
```

This will output:

```
The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000:
{ lambda ; a:MemRef<None>{float32[8,128]} b:MemRef<None>{float32[8,128]} c:MemRef<None>{float32[8,128]}. let
    d:f32[8,128] <- a[:,:]
    e:f32[8,128] <- b[:,:]
    f:f32[8,128] = add d e
    c[:,:] <- f
  in () }

The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000:
module {
  func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space<vmem>>, %arg1: memref<8x128xf32, #tpu.memory_space<vmem>>, %arg2: memref<8x128xf32, #tpu.memory_space<vmem>>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    %c0_1 = arith.constant 0 : index
    %c0_2 = arith.constant 0 : index
    %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    %2 = arith.addf %0, %1 : vector<8x128xf32>
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space<vmem>>, vector<8x128xf32>
    return
  }
}
```
| 144 | + |
| 145 | +### Dumping Mosaic Passes |
| 146 | + |
| 147 | +Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosaic if you are running into errors that are originating from the Mosaic compiler to see what code is actually being generated. |
| 148 | + |
| 149 | +Passing the `--xla_mosaic_dump_to=<directory>` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge. |
| 150 | + |
### Static Verification

The static verification tool can be used to automatically detect race conditions in distributed kernels.
Because this tool uses formal verification, it is best used for small kernels (<=2 devices).

Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=<directory>` flag,
which will output a Promela dump file. Afterwards, the dump file can be
analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run:

```
spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan
```

<!--internal:2-->

## Useful Command-Line Flags

* OOB checks: `--xla_mosaic_on_device_checks=bounds`
* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true`
<!--internal:3-->
* Dump Mosaic: `--xla_mosaic_dump_to=<directory>`
* Enable trace markers in XProf: `--xla_enable_transpose_trace`

## Common Errors

### INTERNAL Mosaic failed to compile TPU Kernel

`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X`

This error means that you hit an unimplemented case in the underlying Mosaic compiler.
Our recommended course of action here is to file a ticket if one does not already
exist for your specific error.

In some cases, your error may be due to an operation which cannot be implemented
efficiently in the compiler, in which case your best course of action is to find a workaround. This
is most commonly seen in `layout` and `shape_cast` errors. The important tip
to remember regarding layouts is that the last 2 dimensions of arrays in Pallas
are physically tiled into registers, so any reshapes, slicing, transposes, etc.
on the last 2 dimensions may trigger a relayout.
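As a rough illustration (the shapes below are made up for this example, and these operations are only costly when they appear inside a kernel), touching leading dimensions is generally safe, while touching the tiled last two dimensions may be expensive or unsupported:

```python
import jax.numpy as jnp

x = jnp.ones((4, 2, 8, 128), dtype=jnp.float32)

# Generally cheap: permuting leading dimensions leaves the tiled
# (8, 128) trailing dimensions untouched.
cheap = x.transpose(1, 0, 2, 3)

# Potentially expensive inside a kernel: swapping the last two
# dimensions changes the physical register tiling and may trigger
# a relayout (or a "Not implemented" error in Mosaic).
costly = x.transpose(0, 1, 3, 2)
```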

### VerificationError

A verification error indicates that Pallas produced invalid code for Mosaic.

This is a bug in Pallas, so please file a bug at https://github.com/jax-ml/jax/issues.

### LoweringError

This is a catch-all error type raised during Pallas-to-Mosaic lowering and can have many causes.
In most cases the error message should hint at what is wrong.

For specific errors:

* `Mixed dtype operands in cmp` when using `jnp.mod`: use `lax.rem` instead of `jnp.mod`
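Note that the two are not interchangeable for negative operands, so check which semantics your kernel needs before swapping:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.int32(-7)
b = jnp.int32(3)

# jnp.mod follows Python semantics: the result takes the sign of
# the divisor.
print(jnp.mod(a, b))   # 2

# lax.rem follows C semantics: the result takes the sign of the
# dividend, so the two differ for negative operands.
print(lax.rem(a, b))   # -1
```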