Skip to content

Commit 8d2943b

Browse files
nputikhin authored and Google-ML-Automation committed
[XLA:GPU] Refactor GEMM fusion pass planning loop to use builder
By splitting out the plan building logic we make the loop itself more compact and understandable. PiperOrigin-RevId: 827889906
1 parent 730851c commit 8d2943b

File tree

1 file changed

+111
-74
lines changed

1 file changed

+111
-74
lines changed

xla/service/gpu/transforms/gemm_fusion.cc

Lines changed: 111 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,6 @@ class AdjacencyList {
106106
std::vector<std::vector<NodeId>> adj_;
107107
};
108108

109-
struct HloAndDimOrder {
110-
const HloInstruction* original_hlo = nullptr;
111-
DimensionOrder dim_order;
112-
};
113-
114109
struct HloAndIterSpec {
115110
const HloInstruction* original_hlo;
116111
TensorIterationSpec iter_spec;
@@ -278,6 +273,87 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
278273
std::get<DotRequirements>(combined_reqs)};
279274
}
280275

276+
class FusionPlanBuilder {
277+
public:
278+
// Builds and returns the FusionPlan. Clears internal state.
279+
FusionPlan BuildPlan() {
280+
FusionPlan fusion_plan;
281+
for (auto& [node_id, node] : node_map_) {
282+
CHECK(node.should_fuse.has_value());
283+
fusion_plan.map[node_id] =
284+
NodeFusionPlan{node.original_hlo, *node.should_fuse};
285+
}
286+
287+
node_map_.clear();
288+
node_reuse_map_.clear();
289+
fusion_plan.graph = std::move(graph_);
290+
return fusion_plan;
291+
}
292+
293+
void ReserveSpaceForOutNeighbors(AdjacencyList::NodeId node_id,
294+
size_t count) {
295+
graph_.ReserveSpaceForOutNeighbors(node_id, count);
296+
}
297+
298+
void AddArc(AdjacencyList::NodeId from, AdjacencyList::NodeId to) {
299+
graph_.AddArc(from, to);
300+
}
301+
302+
const HloInstruction* GetOriginalHlo(AdjacencyList::NodeId node_id) const {
303+
return node_map_.at(node_id).original_hlo;
304+
}
305+
306+
const DimensionOrder& GetDimOrder(AdjacencyList::NodeId node_id) const {
307+
return node_map_.at(node_id).dim_order;
308+
}
309+
310+
// Inserts a node for the given HLO and `dim_order` unless one already exists.
311+
// Returns the node id and a bool indicating if a new node was inserted.
312+
std::pair<AdjacencyList::NodeId, bool> InsertNode(
313+
const HloInstruction& hlo, const DimensionOrder& dim_order) {
314+
HloAndIterSpec reuse_key{&hlo, dim_order.ToTensorIterationSpec()};
315+
316+
// Attempt to insert a placeholder. If the key already exists, inserted is
317+
// false.
318+
auto [it, inserted] = node_reuse_map_.insert({reuse_key, -1});
319+
if (!inserted) {
320+
return {it->second, false};
321+
}
322+
323+
// Key was not present. Create the node and update the map.
324+
AdjacencyList::NodeId node_id = graph_.AddNode();
325+
it->second = node_id;
326+
CHECK(node_map_
327+
.insert({node_id,
328+
Node{&hlo, dim_order, /*should_fuse=*/std::nullopt}})
329+
.second);
330+
return {node_id, true};
331+
}
332+
333+
// Assigns fusion decision for the specified node.
334+
// The node must not have an already assigned decision.
335+
void SetShouldFuseNode(AdjacencyList::NodeId node_id, bool should_fuse) {
336+
Node& node = node_map_.at(node_id);
337+
CHECK(!node.should_fuse.has_value());
338+
node.should_fuse = should_fuse;
339+
}
340+
341+
private:
342+
AdjacencyList graph_;
343+
344+
struct Node {
345+
const HloInstruction* original_hlo;
346+
DimensionOrder dim_order;
347+
std::optional<bool> should_fuse;
348+
};
349+
absl::flat_hash_map<AdjacencyList::NodeId, Node> node_map_;
350+
351+
// Allows reusing nodes when multiple instructions iterate over the same HLO
352+
// using the same iteration spec. In that case we don't duplicate the
353+
// instruction in the fusion.
354+
absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map_;
355+
};
356+
281357
// Builds the fusion map and the requirements which can later be used to
282358
// actually fuse that subgraph.
283359
FusionPlanAndRequirements BuildFusionPlanTowardOperands(
@@ -288,61 +364,32 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
288364
const DotRequirements& requirements_so_far) {
289365
CHECK(!max_params.has_value() || max_params.value() >= 1);
290366

291-
// The graph describing the structure of the fusion that we build - nodes
292-
// corresponding to the instructions and arcs pointing from users to operands.
293-
// We can build and modify this graph easily without the need to create
294-
// HloInstructions at this point.
295-
AdjacencyList graph;
296-
// Stores the original HLO and the dimension order for each node. This is a
297-
// temporary map which is used when processing the nodes in this function.
298-
absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
299-
hlo_and_dim_order_map;
300-
// Stores the information needed to build the fused HLO for each node (what
301-
// was the original HLO and whether we should fuse it or create a parameter).
302-
// This is one of the outputs of this function.
303-
absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
304-
// Allows reusing nodes when multiple instructions iterate over the same HLO
305-
// using the same iteration spec. In that case we don't duplicate the
306-
// instruction in the fusion.
307-
absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
367+
FusionPlanBuilder fusion_builder;
368+
308369
// The requirements imposed by the fusion choices made in this function,
309-
// combined with the existing requirements. This is one of the outputs of this
310-
// function.
370+
// combined with the existing requirements. This is one of the outputs of
371+
// this function.
311372
DotRequirements combined_reqs = requirements_so_far;
312373

313-
auto get_or_create_fusion_node =
314-
[&](const HloInstruction& hlo, const DimensionOrder& dim_order,
315-
bool* is_new_node = nullptr) -> AdjacencyList::NodeId {
316-
HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()};
317-
if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) {
318-
if (is_new_node != nullptr) {
319-
*is_new_node = false;
320-
}
321-
return it->second;
322-
}
323-
AdjacencyList::NodeId node_id = graph.AddNode();
324-
CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second);
325-
CHECK(node_reuse_map.insert({reuse_key, node_id}).second);
326-
if (is_new_node != nullptr) {
327-
*is_new_node = true;
328-
}
329-
return node_id;
330-
};
331374
AdjacencyList::NodeId root =
332-
get_or_create_fusion_node(root_hlo, root_dim_order);
375+
fusion_builder.InsertNode(root_hlo, root_dim_order).first;
333376

334377
// Nodes at the fusion edge that can either get fused too or become parameters
335378
// of the fusion. Used to track the number of parameters.
336379
absl::flat_hash_set<AdjacencyList::NodeId> inputs({root});
380+
337381
std::queue<AdjacencyList::NodeId> queue({root});
338382
int64_t num_requeued = 0;
383+
339384
// BFS
385+
// If all queued instructions are re-queued, they all exceed the parameter
386+
// limit, so stop fusing.
340387
while (queue.size() > num_requeued) {
341388
AdjacencyList::NodeId node_id = queue.front();
342389
queue.pop();
343-
const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
344-
const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo;
345-
const DimensionOrder& dim_order = hlo_and_dim_order.dim_order;
390+
const HloInstruction& original_hlo =
391+
*fusion_builder.GetOriginalHlo(node_id);
392+
const DimensionOrder& dim_order = fusion_builder.GetDimOrder(node_id);
346393

347394
// Watch the total number of fusion parameters.
348395
if (max_params.has_value() &&
@@ -355,55 +402,45 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
355402
continue;
356403
}
357404
num_requeued = 0;
405+
358406
if (original_hlo.opcode() == HloOpcode::kParameter) {
359-
CHECK(fusion_plan_map
360-
.insert({node_id, {&original_hlo, /*should_fuse=*/false}})
361-
.second);
407+
fusion_builder.SetShouldFuseNode(node_id, false);
362408
continue;
363409
}
410+
364411
auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
365412
original_hlo, dim_order, properties, gpu_version, combined_reqs);
366413
if (!opt_result.has_value()) {
367-
CHECK(fusion_plan_map
368-
.insert({node_id, {&original_hlo, /*should_fuse=*/false}})
369-
.second);
414+
fusion_builder.SetShouldFuseNode(node_id, false);
370415
continue;
371416
}
417+
372418
const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders);
373419
combined_reqs = std::move(opt_result->requirements);
420+
374421
inputs.erase(node_id);
375-
graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count());
376-
for (int64_t i = 0; i < original_hlo.operand_count(); ++i) {
377-
const HloInstruction& operand = *original_hlo.operand(i);
378-
const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand);
379-
bool is_new_node = false;
380-
AdjacencyList::NodeId operand_node_id =
381-
get_or_create_fusion_node(operand, operand_dim_order, &is_new_node);
382-
graph.AddArc(node_id, operand_node_id);
422+
fusion_builder.ReserveSpaceForOutNeighbors(node_id,
423+
original_hlo.operand_count());
424+
for (const HloInstruction* operand : original_hlo.operands()) {
425+
const DimensionOrder& operand_dim_order = operand_dim_orders.at(operand);
426+
auto [operand_node_id, is_new_node] =
427+
fusion_builder.InsertNode(*operand, operand_dim_order);
428+
fusion_builder.AddArc(node_id, operand_node_id);
383429
if (is_new_node) {
384-
VLOG(6) << "Enqueueing " << operand.ToString() << ":"
430+
VLOG(6) << "Enqueueing " << operand->ToString() << ":"
385431
<< operand_dim_order.ToString();
386432
inputs.insert(operand_node_id);
387433
queue.push(operand_node_id);
388434
}
389435
}
390-
CHECK(
391-
fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}})
392-
.second);
436+
fusion_builder.SetShouldFuseNode(node_id, true);
393437
}
394438
// Handle the remaining requeued items.
395-
while (!queue.empty()) {
439+
for (; !queue.empty(); queue.pop()) {
396440
AdjacencyList::NodeId node_id = queue.front();
397-
queue.pop();
398-
399-
const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
400-
CHECK(fusion_plan_map
401-
.insert({node_id,
402-
{hlo_and_dim_order.original_hlo, /*should_fuse=*/false}})
403-
.second);
441+
fusion_builder.SetShouldFuseNode(node_id, false);
404442
}
405-
return {{std::move(graph), std::move(fusion_plan_map)},
406-
std::move(combined_reqs)};
443+
return {fusion_builder.BuildPlan(), std::move(combined_reqs)};
407444
}
408445

409446
// Builds the HLO instructions for the fusion represented by `fusion_plan`,

0 commit comments

Comments
 (0)