## Partitioning
See [computation_partitioner.h](https://github.com/openxla/xla/blob/ca62f3e1bc9ea1d808c3a4de0a78bae7453389eb/xla/codegen/emitters/computation_partitioner.h).
Non-elementwise HLO instructions cannot always be emitted together. Consider the
following HLO graph:
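The concrete graph from the full document is not reproduced in this excerpt; a minimal hypothetical fusion with the same property looks like this (names and shapes are made up):

```
fused_computation {
  %param = f32[64,64] parameter(0)
  %log = f32[64,64] log(%param)
  %transpose = f32[64,64] transpose(%log), dimensions={1,0}
  ROOT %add = f32[64,64] add(%log, %transpose)
}
```

When `%add` is emitted elementwise at index `(i, j)`, it reads `%log` at `(i, j)` directly and at `(j, i)` through `%transpose`, so `%log` cannot simply be emitted inline inside the function that computes `%add`.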
## Elemental emission
See [elemental_hlo_to_mlir.h](https://github.com/openxla/xla/blob/ca62f3e1bc9ea1d808c3a4de0a78bae7453389eb/xla/codegen/emitters/elemental_hlo_to_mlir.h).
Elemental emission creates loops and math/arith ops for `HloInstructions`. For the most part, this is straightforward, but there are some interesting things going on here.
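As a rough, hypothetical sketch (the function name, shape, and op are made up and not taken from real emitter output), an elementwise op essentially becomes a pure function that extracts the operand element at the given index, applies the corresponding math/arith op, and returns the scalar:

```mlir
// Hypothetical sketch of elemental emission for an elementwise exp:
// extract the operand element at the caller-provided index, apply the math
// op, and return the scalar result.
func.func @exp_elementwise(%operand: tensor<1024xf32>, %i: index) -> f32 {
  %elem = tensor.extract %operand[%i] : tensor<1024xf32>
  %result = math.exp %elem : f32
  func.return %result : f32
}
```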
### Loop emitter
See [loop.h](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/emitters/loop.h#L4).
Let's study the most important passes of the MLIR compilation pipeline using the loop emitter as a running example.
See [lower_xla_gpu_to_scf.cc](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/lower_xla_gpu_to_scf.cc).
`xla_gpu.loop` represents a loop nest with a boundary check inside. If the loop induction variables are out of bounds of the indexing map domain, then this iteration is skipped. This means that the loop is converted to one or more nested `scf.for` ops.
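A minimal sketch of the resulting structure, assuming a made-up 1-D domain and hypothetical tensor shapes (this is not actual pass output): the bound check becomes an `scf.if` inside the `scf.for`, and out-of-domain iterations simply forward the accumulator unchanged:

```mlir
// Hypothetical scf form: a loop over [0, 1024) whose indexing map domain
// only covers [0, 1000); out-of-domain iterations yield the accumulator as-is.
func.func @guarded_loop(%input: tensor<1024xf32>,
                        %init: tensor<1000xf32>) -> tensor<1000xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1000 = arith.constant 1000 : index
  %c1024 = arith.constant 1024 : index
  %result = scf.for %i = %c0 to %c1024 step %c1
      iter_args(%acc = %init) -> (tensor<1000xf32>) {
    // Skip iterations whose induction variable lies outside the domain.
    %in_bounds = arith.cmpi slt, %i, %c1000 : index
    %next = scf.if %in_bounds -> (tensor<1000xf32>) {
      %v = tensor.extract %input[%i] : tensor<1024xf32>
      %updated = tensor.insert %v into %acc[%i] : tensor<1000xf32>
      scf.yield %updated : tensor<1000xf32>
    } else {
      scf.yield %acc : tensor<1000xf32>
    }
    scf.yield %next : tensor<1000xf32>
  }
  func.return %result : tensor<1000xf32>
}
```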
#### Flatten tensors
See [flatten_tensors.cc](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/flatten_tensors.cc).
The N-d tensors are projected onto 1D. This simplifies the vectorization and the lowering to LLVM, because every tensor access now corresponds to how the data is laid out in memory.
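For illustration only (hypothetical shapes; the real pass expresses the linearization through indexing maps rather than raw `arith` ops), here is a 2-D access and its flattened 1-D equivalent:

```mlir
// Before flattening: a 2-D tensor access.
func.func @before(%t: tensor<64x32xf32>, %i: index, %j: index) -> f32 {
  %v = tensor.extract %t[%i, %j] : tensor<64x32xf32>
  func.return %v : f32
}

// After flattening: the same access on a 1-D tensor with a linearized index
// (%i * 32 + %j), matching the row-major memory layout.
func.func @after(%t: tensor<2048xf32>, %i: index, %j: index) -> f32 {
  %c32 = arith.constant 32 : index
  %row = arith.muli %i, %c32 : index
  %linear = arith.addi %row, %j : index
  %v = tensor.extract %t[%linear] : tensor<2048xf32>
  func.return %v : f32
}
```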
See [vectorize_loads_stores.cc](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc).
The pass analyses the indices in the `tensor.extract` and `tensor.insert` ops, and if they are produced by an `xla_gpu.apply_indexing` that accesses the elements contiguously, the scalar loads and stores are rewritten into vector ones.
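As a hypothetical illustration (shapes and names are made up), four scalar reads at consecutive indices can be replaced by a single 4-wide vector read:

```mlir
// Hypothetical result of vectorization: one vector load instead of four
// scalar tensor.extract ops at %base, %base+1, %base+2, %base+3.
func.func @vector_load(%t: tensor<1024xf32>, %base: index) -> vector<4xf32> {
  %pad = arith.constant 0.0 : f32
  %v = vector.transfer_read %t[%base], %pad : tensor<1024xf32>, vector<4xf32>
  func.return %v : vector<4xf32>
}
```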
See [optimize_loops.cc](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/optimize_loops.cc).
The loop unrolling finds `scf.for` loops that can be unrolled. In this case, the
loop over the elements of the vector disappears.
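A hypothetical before/after pair (not actual pass output) showing how the loop over the vector elements disappears:

```mlir
// Before: an scf.for iterating over the four elements of a vector.
func.func @sum_loop(%v: vector<4xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %zero = arith.constant 0.0 : f32
  %sum = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> (f32) {
    %e = vector.extractelement %v[%i : index] : vector<4xf32>
    %s = arith.addf %acc, %e : f32
    scf.yield %s : f32
  }
  func.return %sum : f32
}

// After unrolling: the loop is gone; only straight-line element reads and
// adds remain.
func.func @sum_unrolled(%v: vector<4xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %zero = arith.constant 0.0 : f32
  %e0 = vector.extractelement %v[%c0 : index] : vector<4xf32>
  %e1 = vector.extractelement %v[%c1 : index] : vector<4xf32>
  %e2 = vector.extractelement %v[%c2 : index] : vector<4xf32>
  %e3 = vector.extractelement %v[%c3 : index] : vector<4xf32>
  %s0 = arith.addf %zero, %e0 : f32
  %s1 = arith.addf %s0, %e1 : f32
  %s2 = arith.addf %s1, %e2 : f32
  %s3 = arith.addf %s2, %e3 : f32
  func.return %s3 : f32
}
```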
We cannot use the `memref` lowerings for tensors, since we don't bufferize the IR and our ABI is not compatible with the `memref` ABI. Instead, we have a custom lowering directly from tensors to `LLVM`.
- The lowering of tensors is done in [lower_tensors.cc](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/lower_tensors.cc). `tensor.extract` is lowered to `llvm.load`, `tensor.insert` to `llvm.store`, in the obvious way (see the sketch after this list).
- [propagate_slice_indices](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/propagate_slice_indices.cc) and [merge_pointers_to_same_slice](https://github.com/openxla/xla/blob/cfd16b7f21feff17635c782f4489c0f478178eb9/xla/backends/gpu/codegen/transforms/merge_pointers_to_same_slice.cc) together implement a detail of buffer assignment and XLA's ABI: if two tensors share the same buffer slice, they are only passed once. These passes deduplicate the function arguments.
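As a minimal sketch of the first point, assuming the tensor argument has already become a raw pointer (names and types are hypothetical, not actual pass output), a `tensor.extract` ends up as an `llvm.getelementptr` plus `llvm.load`:

```mlir
// Hypothetical lowering of `tensor.extract %t[%i] : tensor<1024xf32>` once
// %t is passed as a raw LLVM pointer.
func.func @extract_lowered(%base: !llvm.ptr, %i: index) -> f32 {
  %idx = arith.index_cast %i : index to i64
  %addr = llvm.getelementptr %base[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %val = llvm.load %addr : !llvm.ptr -> f32
  func.return %val : f32
}
```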
### Reproducer
To see the IR after every pass of the compilation pipeline, launch `run_hlo_module` with the `--v=5` flag.