Skip to content

Commit 49d8c56

Browse files
yijie-yangcopybara-github
authored andcommitted
Collect attributes from pinned nodes for group nodes in HLO to JSON conversion
PiperOrigin-RevId: 853389382
1 parent 8105b05 commit 49d8c56

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

src/builtin-adapter/graphnode_builder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <string>
2020
#include <utility>
21+
#include <vector>
2122

2223
#include "absl/strings/string_view.h"
2324
#include "formats/schema_structs.h"
@@ -68,6 +69,11 @@ class GraphNodeBuilder {
6869
// Get the attribute value for the given key.
6970
absl::string_view GetNodeAttribute(absl::string_view key);
7071

72+
// Returns all the attributes.
73+
const std::vector<Attribute>& GetNodeAttributes() const {
74+
return node_.node_attrs;
75+
}
76+
7177
// Appends the attribute to the input or output metadata list. If the metadata
7278
// already exists, we append the attribute to that metadata. If it doesn't
7379
// exist, we create a new metadata and add it to the list, then append the

src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ constexpr int kMaxUsersToRender = 16;
7979
using OutputEdges =
8080
absl::flat_hash_map<std::string, std::vector<const xla::HloInstruction*>>;
8181

82+
// GroupNodeAttributes is a map from group namespace to a map of attribute key
83+
// to attribute value.
84+
using GroupNodeAttributes =
85+
absl::flat_hash_map<std::string,
86+
absl::flat_hash_map<std::string, std::string>>;
87+
8288
// Returns true if value is considered empty.
8389
bool IsEmpty(const llvm::json::Value& value) {
8490
if (const auto* obj = value.getAsObject()) {
@@ -503,14 +509,16 @@ void PopulateOutputsMetadata(
503509
// `computation_expand`: decide which computation to expand.
504510
//
505511
// `output_edges`: record all the edges from the instruction to its users.
512+
//
513+
// `group_node_attributes`: record all the attributes from the group node.
506514
absl::Status HloComputationToGraphImpl(
507515
const xla::HloComputation& computation, const NodeFilter& node_filter,
508516
const ComputationExpand& computation_expand,
509517
const HloAdapterOption& options,
510518
absl::flat_hash_set<const xla::HloComputation*>& built_computations,
511519
std::vector<std::string>& computation_stack,
512520
std::vector<GraphNodeBuilder>& instruction_node_builders,
513-
OutputEdges& output_edges) {
521+
OutputEdges& output_edges, GroupNodeAttributes& group_node_attributes) {
514522
if (built_computations.contains(&computation)) {
515523
return absl::OkStatus();
516524
}
@@ -526,6 +534,12 @@ absl::Status HloComputationToGraphImpl(
526534
output_edges, computation_expand));
527535
builder.SetNodeLabel(GetInstructionId(computation.FusionInstruction()));
528536
builder.AppendNodeAttribute(kFusionComputation, computation.name());
537+
538+
// Populate group node attributes from the pinned node.
539+
const std::string& group_name = builder.GetNodeName();
540+
for (const auto& attr : builder.GetNodeAttributes()) {
541+
group_node_attributes[group_name][attr.key] = attr.value;
542+
}
529543
} else {
530544
// Build the pinned node representing the computation.
531545
builder.SetNodeId(GetComputationId(&computation));
@@ -554,7 +568,7 @@ absl::Status HloComputationToGraphImpl(
554568
RETURN_IF_ERROR(HloComputationToGraphImpl(
555569
*(instruction->fused_instructions_computation()), node_filter,
556570
computation_expand, options, built_computations, computation_stack,
557-
instruction_node_builders, output_edges));
571+
instruction_node_builders, output_edges, group_node_attributes));
558572
computation_stack.pop_back();
559573
} else if (IsGetTupleElement(options, instruction)) {
560574
continue;
@@ -576,7 +590,7 @@ absl::Status HloComputationToGraphImpl(
576590
RETURN_IF_ERROR(HloComputationToGraphImpl(
577591
*subcomputation, node_filter, computation_expand, options,
578592
built_computations, computation_stack, instruction_node_builders,
579-
output_edges));
593+
output_edges, group_node_attributes));
580594
int cur_node_count = instruction_node_builders.size();
581595
computation_stack.pop_back();
582596

@@ -633,7 +647,8 @@ absl::StatusOr<GraphCollection> HloToGraph(
633647

634648
RETURN_IF_ERROR(HloComputationToGraphImpl(
635649
computation, node_filter, computation_expand, options, built_computations,
636-
computation_stack, instruction_node_builders, output_edges));
650+
computation_stack, instruction_node_builders, output_edges,
651+
graph.subgraphs.back().group_node_attributes));
637652

638653
for (GraphNodeBuilder& builder : instruction_node_builders) {
639654
if (const auto& it = output_edges.find(builder.GetNodeId());

0 commit comments

Comments
 (0)