@@ -79,6 +79,12 @@ constexpr int kMaxUsersToRender = 16;
7979using OutputEdges =
8080 absl::flat_hash_map<std::string, std::vector<const xla::HloInstruction*>>;
8181
82+ // GroupNodeAttributes is a map from group namespace to a map of attribute key
83+ // to attribute value.
84+ using GroupNodeAttributes =
85+ absl::flat_hash_map<std::string,
86+ absl::flat_hash_map<std::string, std::string>>;
87+
8288// Returns true if value is considered empty.
8389bool IsEmpty (const llvm::json::Value& value) {
8490 if (const auto * obj = value.getAsObject ()) {
@@ -503,14 +509,16 @@ void PopulateOutputsMetadata(
503509// `computation_expand`: decide which computation to expand.
504510//
505511// `output_edges`: record all the edges from the instruction to its users.
512+ //
513+ // `group_node_attributes`: record all the attributes from the group node.
506514absl::Status HloComputationToGraphImpl (
507515 const xla::HloComputation& computation, const NodeFilter& node_filter,
508516 const ComputationExpand& computation_expand,
509517 const HloAdapterOption& options,
510518 absl::flat_hash_set<const xla::HloComputation*>& built_computations,
511519 std::vector<std::string>& computation_stack,
512520 std::vector<GraphNodeBuilder>& instruction_node_builders,
513- OutputEdges& output_edges) {
521+ OutputEdges& output_edges, GroupNodeAttributes& group_node_attributes ) {
514522 if (built_computations.contains (&computation)) {
515523 return absl::OkStatus ();
516524 }
@@ -526,6 +534,12 @@ absl::Status HloComputationToGraphImpl(
526534 output_edges, computation_expand));
527535 builder.SetNodeLabel (GetInstructionId (computation.FusionInstruction ()));
528536 builder.AppendNodeAttribute (kFusionComputation , computation.name ());
537+
538+ // Populate group node attributes from the pinned node.
539+ const std::string& group_name = builder.GetNodeName ();
540+ for (const auto & attr : builder.GetNodeAttributes ()) {
541+ group_node_attributes[group_name][attr.key ] = attr.value ;
542+ }
529543 } else {
530544 // Build the pinned node representing the computation.
531545 builder.SetNodeId (GetComputationId (&computation));
@@ -554,7 +568,7 @@ absl::Status HloComputationToGraphImpl(
554568 RETURN_IF_ERROR (HloComputationToGraphImpl (
555569 *(instruction->fused_instructions_computation ()), node_filter,
556570 computation_expand, options, built_computations, computation_stack,
557- instruction_node_builders, output_edges));
571+ instruction_node_builders, output_edges, group_node_attributes ));
558572 computation_stack.pop_back ();
559573 } else if (IsGetTupleElement (options, instruction)) {
560574 continue ;
@@ -576,7 +590,7 @@ absl::Status HloComputationToGraphImpl(
576590 RETURN_IF_ERROR (HloComputationToGraphImpl (
577591 *subcomputation, node_filter, computation_expand, options,
578592 built_computations, computation_stack, instruction_node_builders,
579- output_edges));
593+ output_edges, group_node_attributes ));
580594 int cur_node_count = instruction_node_builders.size ();
581595 computation_stack.pop_back ();
582596
@@ -633,7 +647,8 @@ absl::StatusOr<GraphCollection> HloToGraph(
633647
634648 RETURN_IF_ERROR (HloComputationToGraphImpl (
635649 computation, node_filter, computation_expand, options, built_computations,
636- computation_stack, instruction_node_builders, output_edges));
650+ computation_stack, instruction_node_builders, output_edges,
651+ graph.subgraphs .back ().group_node_attributes ));
637652
638653 for (GraphNodeBuilder& builder : instruction_node_builders) {
639654 if (const auto & it = output_edges.find (builder.GetNodeId ());
0 commit comments