@@ -39,6 +39,45 @@ limitations under the License.
3939namespace xla {
4040namespace gpu {
4141
42+ // PriorityFusion is the main fusion pass for XLA:GPU. It is an HLO pass that
43+ // assigns a priority to each producer instruction based on the estimated
44+ // performance benefit of fusing it into its consumers. The benefit is
45+ // calculated using a performance cost model:
46+ //
47+ // priority = time_unfused - time_fused
48+ //
49+ // Note: If fusing a producer into its consumers requires duplicating the
50+ // producer, the cost model accounts for this duplication.
51+ //
52+ // The algorithm can be summarized in the following steps:
53+ // 1. For each producer, call the cost model to estimate the potential benefit
54+ // of fusing it with all its consumers.
55+ // 2. Put all producers with a positive benefit into a priority queue, ordered
56+ // by benefit.
57+ // 3. Pop the producer with the highest priority from the queue.
58+ // 4. Fuse the producer with its consumers. This may result in a new fusion
59+ // instruction, or merging into an existing fusion.
60+ // 5. Update the priorities of the operands of the fused instructions and
61+ // of instructions whose consumers have changed, and update them in the
62+ // priority queue.
63+ // 6. If the queue is not empty, go to step 3.
64+ //
65+ // Example:
66+ // Consider A -> B -> C, where A, B, and C are fusible operations.
67+ // The fusible producers are A and B.
68+ //
69+ // Priorities are computed:
70+ // - P(A) = benefit of fusing A into B.
71+ // - P(B) = benefit of fusing B into C.
72+ //
73+ // Assuming P(A)=10 and P(B)=5, the queue is [(A,10), (B,5)].
74+ // - A is popped and fused into B, creating fusion(A+B).
75+ // - The graph becomes fusion(A+B) -> C.
76+ // - Priority of fusion(A+B) is computed, P(fusion(A+B))=8.
77+ // - The queue becomes [(fusion(A+B),8)].
78+ // - fusion(A+B) is popped and fused into C, creating fusion(A+B+C).
79+ // - The queue becomes empty, and fusion terminates.
80+ //
4281class PriorityFusion : public HloModulePass {
4382 public:
4483 PriorityFusion (tsl::thread::ThreadPool* thread_pool,
0 commit comments