@@ -106,11 +106,6 @@ class AdjacencyList {
106106 std::vector<std::vector<NodeId>> adj_;
107107};
108108
109- struct HloAndDimOrder {
110- const HloInstruction* original_hlo = nullptr ;
111- DimensionOrder dim_order;
112- };
113-
114109struct HloAndIterSpec {
115110 const HloInstruction* original_hlo;
116111 TensorIterationSpec iter_spec;
@@ -278,6 +273,87 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
278273 std::get<DotRequirements>(combined_reqs)};
279274}
280275
276+ class FusionPlanBuilder {
277+ public:
278+ // Builds and returns the FusionPlan. Clears internal state.
279+ FusionPlan BuildPlan () {
280+ FusionPlan fusion_plan;
281+ for (auto & [node_id, node] : node_map_) {
282+ CHECK (node.should_fuse .has_value ());
283+ fusion_plan.map [node_id] =
284+ NodeFusionPlan{node.original_hlo , *node.should_fuse };
285+ }
286+
287+ node_map_.clear ();
288+ node_reuse_map_.clear ();
289+ fusion_plan.graph = std::move (graph_);
290+ return fusion_plan;
291+ }
292+
293+ void ReserveSpaceForOutNeighbors (AdjacencyList::NodeId node_id,
294+ size_t count) {
295+ graph_.ReserveSpaceForOutNeighbors (node_id, count);
296+ }
297+
298+ void AddArc (AdjacencyList::NodeId from, AdjacencyList::NodeId to) {
299+ graph_.AddArc (from, to);
300+ }
301+
302+ const HloInstruction* GetOriginalHlo (AdjacencyList::NodeId node_id) const {
303+ return node_map_.at (node_id).original_hlo ;
304+ }
305+
306+ const DimensionOrder& GetDimOrder (AdjacencyList::NodeId node_id) const {
307+ return node_map_.at (node_id).dim_order ;
308+ }
309+
310+ // Inserts a node for the given HLO and `dim_order` unless already exists.
311+ // Returns the node id and a bool indicating if a new node was inserted.
312+ std::pair<AdjacencyList::NodeId, bool > InsertNode (
313+ const HloInstruction& hlo, const DimensionOrder& dim_order) {
314+ HloAndIterSpec reuse_key{&hlo, dim_order.ToTensorIterationSpec ()};
315+
316+ // Attempt to insert a placeholder. If the key already exists, inserted is
317+ // false.
318+ auto [it, inserted] = node_reuse_map_.insert ({reuse_key, -1 });
319+ if (!inserted) {
320+ return {it->second , false };
321+ }
322+
323+ // Key was not present. Create the node and update the map.
324+ AdjacencyList::NodeId node_id = graph_.AddNode ();
325+ it->second = node_id;
326+ CHECK (node_map_
327+ .insert ({node_id,
328+ Node{&hlo, dim_order, /* should_fuse=*/ std::nullopt }})
329+ .second );
330+ return {node_id, true };
331+ }
332+
333+ // Assigns fusion decision for the specified node.
334+ // The node must not have an already assigned decision.
335+ void SetShouldFuseNode (AdjacencyList::NodeId node_id, bool should_fuse) {
336+ Node& node = node_map_.at (node_id);
337+ CHECK (!node.should_fuse .has_value ());
338+ node.should_fuse = should_fuse;
339+ }
340+
341+ private:
342+ AdjacencyList graph_;
343+
344+ struct Node {
345+ const HloInstruction* original_hlo;
346+ DimensionOrder dim_order;
347+ std::optional<bool > should_fuse;
348+ };
349+ absl::flat_hash_map<AdjacencyList::NodeId, Node> node_map_;
350+
351+ // Allows reusing nodes when multiple instructions iterate over the same HLO
352+ // using the same iteration spec. In that case we don't duplicate the
353+ // instruction in the fusion.
354+ absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map_;
355+ };
356+
281357// Builds the fusion map and the requirements which can later be used to
282358// actually fuse that subgraph.
283359FusionPlanAndRequirements BuildFusionPlanTowardOperands (
@@ -288,61 +364,32 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
288364 const DotRequirements& requirements_so_far) {
289365 CHECK (!max_params.has_value () || max_params.value () >= 1 );
290366
291- // The graph describing the structure of the fusion that we build - nodes
292- // corresponding to the instructions and arcs pointing from users to operands.
293- // We can build and modify this graph easily without the need to create
294- // HloInstructions at this point.
295- AdjacencyList graph;
296- // Stores the original HLO and the dimension order for each node. This is a
297- // temporary map which is used when processing the nodes in this function.
298- absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
299- hlo_and_dim_order_map;
300- // Stores the information needed to build the fused HLO for each node (what
301- // was the original HLO and whether we should fuse it or create a parameter).
302- // This is one of the outputs of this function.
303- absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
304- // Allows reusing nodes when multiple instructions iterate over the same HLO
305- // using the same iteration spec. In that case we don't duplicate the
306- // instruction in the fusion.
307- absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
367+ FusionPlanBuilder fusion_builder;
368+
308369 // The requirements imposed by the fusion choices made in this function,
309- // combined with the existing requirements. This is one of the outputs of this
310- // function.
370+ // combined with the existing requirements. This is one of the outputs of
371+ // this function.
311372 DotRequirements combined_reqs = requirements_so_far;
312373
313- auto get_or_create_fusion_node =
314- [&](const HloInstruction& hlo, const DimensionOrder& dim_order,
315- bool * is_new_node = nullptr ) -> AdjacencyList::NodeId {
316- HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec ()};
317- if (auto it = node_reuse_map.find (reuse_key); it != node_reuse_map.end ()) {
318- if (is_new_node != nullptr ) {
319- *is_new_node = false ;
320- }
321- return it->second ;
322- }
323- AdjacencyList::NodeId node_id = graph.AddNode ();
324- CHECK (hlo_and_dim_order_map.insert ({node_id, {&hlo, dim_order}}).second );
325- CHECK (node_reuse_map.insert ({reuse_key, node_id}).second );
326- if (is_new_node != nullptr ) {
327- *is_new_node = true ;
328- }
329- return node_id;
330- };
331374 AdjacencyList::NodeId root =
332- get_or_create_fusion_node (root_hlo, root_dim_order);
375+ fusion_builder. InsertNode (root_hlo, root_dim_order). first ;
333376
334377 // Nodes at the fusion edge that can either get fused too or become parameters
335378 // of the fusion. Used to track the number of parameters.
336379 absl::flat_hash_set<AdjacencyList::NodeId> inputs ({root});
380+
337381 std::queue<AdjacencyList::NodeId> queue ({root});
338382 int64_t num_requeued = 0 ;
383+
339384 // BFS
385+ // If all queued instructions are re-queued, they all exceed the parameter
386+ // limit, so stop fusing.
340387 while (queue.size () > num_requeued) {
341388 AdjacencyList::NodeId node_id = queue.front ();
342389 queue.pop ();
343- const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map. at (node_id);
344- const HloInstruction& original_hlo = *hlo_and_dim_order. original_hlo ;
345- const DimensionOrder& dim_order = hlo_and_dim_order. dim_order ;
390+ const HloInstruction& original_hlo =
391+ *fusion_builder. GetOriginalHlo (node_id) ;
392+ const DimensionOrder& dim_order = fusion_builder. GetDimOrder (node_id) ;
346393
347394 // Watch the total number of fusion parameters.
348395 if (max_params.has_value () &&
@@ -355,55 +402,45 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
355402 continue ;
356403 }
357404 num_requeued = 0 ;
405+
358406 if (original_hlo.opcode () == HloOpcode::kParameter ) {
359- CHECK (fusion_plan_map
360- .insert ({node_id, {&original_hlo, /* should_fuse=*/ false }})
361- .second );
407+ fusion_builder.SetShouldFuseNode (node_id, false );
362408 continue ;
363409 }
410+
364411 auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable (
365412 original_hlo, dim_order, properties, gpu_version, combined_reqs);
366413 if (!opt_result.has_value ()) {
367- CHECK (fusion_plan_map
368- .insert ({node_id, {&original_hlo, /* should_fuse=*/ false }})
369- .second );
414+ fusion_builder.SetShouldFuseNode (node_id, false );
370415 continue ;
371416 }
417+
372418 const DimOrderMap operand_dim_orders = std::move (opt_result->dim_orders );
373419 combined_reqs = std::move (opt_result->requirements );
420+
374421 inputs.erase (node_id);
375- graph.ReserveSpaceForOutNeighbors (node_id, original_hlo.operand_count ());
376- for (int64_t i = 0 ; i < original_hlo.operand_count (); ++i) {
377- const HloInstruction& operand = *original_hlo.operand (i);
378- const DimensionOrder& operand_dim_order = operand_dim_orders.at (&operand);
379- bool is_new_node = false ;
380- AdjacencyList::NodeId operand_node_id =
381- get_or_create_fusion_node (operand, operand_dim_order, &is_new_node);
382- graph.AddArc (node_id, operand_node_id);
422+ fusion_builder.ReserveSpaceForOutNeighbors (node_id,
423+ original_hlo.operand_count ());
424+ for (const HloInstruction* operand : original_hlo.operands ()) {
425+ const DimensionOrder& operand_dim_order = operand_dim_orders.at (operand);
426+ auto [operand_node_id, is_new_node] =
427+ fusion_builder.InsertNode (*operand, operand_dim_order);
428+ fusion_builder.AddArc (node_id, operand_node_id);
383429 if (is_new_node) {
384- VLOG (6 ) << " Enqueueing " << operand. ToString () << " :"
430+ VLOG (6 ) << " Enqueueing " << operand-> ToString () << " :"
385431 << operand_dim_order.ToString ();
386432 inputs.insert (operand_node_id);
387433 queue.push (operand_node_id);
388434 }
389435 }
390- CHECK (
391- fusion_plan_map.insert ({node_id, {&original_hlo, /* should_fuse=*/ true }})
392- .second );
436+ fusion_builder.SetShouldFuseNode (node_id, true );
393437 }
394438 // Handle the remaining requeued items.
395- while ( !queue.empty ()) {
439+ for (; !queue.empty (); queue. pop ()) {
396440 AdjacencyList::NodeId node_id = queue.front ();
397- queue.pop ();
398-
399- const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at (node_id);
400- CHECK (fusion_plan_map
401- .insert ({node_id,
402- {hlo_and_dim_order.original_hlo , /* should_fuse=*/ false }})
403- .second );
441+ fusion_builder.SetShouldFuseNode (node_id, false );
404442 }
405- return {{std::move (graph), std::move (fusion_plan_map)},
406- std::move (combined_reqs)};
443+ return {fusion_builder.BuildPlan (), std::move (combined_reqs)};
407444}
408445
409446// Builds the HLO instructions for the fusion represented by `fusion_plan`,
0 commit comments