Skip to content

Commit 1351df8

Browse files
derdrdirkGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Add documentation to Priority Fusion pass.
PiperOrigin-RevId: 837168939
1 parent 696445c commit 1351df8

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

xla/service/gpu/transforms/priority_fusion.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,45 @@ limitations under the License.
3939
namespace xla {
4040
namespace gpu {
4141

42+
// PriorityFusion is the main fusion pass for XLA:GPU. It is an HLO pass that
43+
// assigns a priority to each producer instruction based on the estimated
44+
// performance benefit of fusing it into its consumers. The benefit is
45+
// calculated using a performance cost model:
46+
//
47+
// priority = time_unfused - time_fused
48+
//
49+
// Note: If fusing a producer into its consumers requires duplicating the
50+
// producer, the cost model accounts for this duplication.
51+
//
52+
// The algorithm can be summarized in the following steps:
53+
// 1. For each producer, call the cost model to estimate the potential benefit
54+
// of fusing it with all its consumers.
55+
// 2. Put all producers with a positive benefit into a priority queue, ordered
56+
// by benefit.
57+
// 3. Pop the producer with the highest priority from the queue.
58+
// 4. Fuse the producer with its consumers. This may result in a new fusion
59+
// instruction, or merging into an existing fusion.
60+
// 5. Update the priorities of the operands of the fused instructions and
61+
// of instructions whose consumers have changed, and update them in the
62+
// priority queue.
63+
// 6. If the queue is not empty, go to step 3.
64+
//
65+
// Example:
66+
// Consider A -> B -> C, where A, B, and C are fusible operations.
67+
// The fusible producers are A and B.
68+
//
69+
// Priorities are computed:
70+
// - P(A) = benefit of fusing A into B.
71+
// - P(B) = benefit of fusing B into C.
72+
//
73+
// Assuming P(A)=10 and P(B)=5, the queue is [(A,10), (B,5)].
74+
// - A is popped and fused into B, creating fusion(A+B).
75+
// - The graph becomes fusion(A+B) -> C.
76+
// - Priority of fusion(A+B) is computed, P(fusion(A+B))=8.
77+
// - The queue becomes [(fusion(A+B),8)].
78+
// - fusion(A+B) is popped and fused into C, creating fusion(A+B+C).
79+
// - The queue becomes empty, and fusion terminates.
80+
//
4281
class PriorityFusion : public HloModulePass {
4382
public:
4483
PriorityFusion(tsl::thread::ThreadPool* thread_pool,

0 commit comments

Comments
 (0)