Skip to content

Commit e6667b5

Browse files
Includes instruction source files in the Auto Sharding request proto.
PiperOrigin-RevId: 666104783
1 parent e990858 commit e6667b5

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

xla/hlo/experimental/auto_sharding/auto_sharding.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,7 @@ AutoShardingSolverResult CallSolver(
20182018
request.add_instruction_names(
20192019
absl::StrCat(instruction_name, " (id: ", node_idx, ")"));
20202020
request.add_opcodes(std::string(opcode));
2021+
request.add_metadata_source_files(instruction->metadata().source_file());
20212022
AutoShardingSolverRequest_Costs ci, di, mi, pi;
20222023
AutoShardingSolverRequest_Names strategy_names;
20232024
std::optional<HloSharding> default_strategy;

xla/hlo/experimental/auto_sharding/auto_sharding.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ message AutoShardingSolverRequest {
6868
repeated Costs value_costs = 15;
6969
repeated string instruction_names = 16;
7070
repeated string opcodes = 33;
71+
repeated string metadata_source_files = 40;
7172
repeated Names strategy_names = 32;
7273
optional SolverTimeout solver_timeout = 17;
7374
optional Coeff overbudget_coeff = 18;

xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() {
141141
1, 0, 1,
142142
1, 1, 0}};
143143
const std::vector<std::string> instruction_names = {"A", "B", "C", "D", "E"};
144+
const std::vector<std::string> metadata_source_files = {"attention.py",
145+
"convolution.py",
146+
"layers.py",
147+
"logits.py",
148+
"pipeline.py"};
144149

145150
AutoShardingSolverRequest request;
146151
request.set_num_nodes(5);
@@ -159,6 +164,8 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() {
159164
AddCosts(request.mutable_value_costs(), v);
160165
request.mutable_instruction_names()->Add(instruction_names.begin(),
161166
instruction_names.end());
167+
request.mutable_metadata_source_files()->Add(metadata_source_files.begin(),
168+
metadata_source_files.end());
162169
return request;
163170
}
164171

0 commit comments

Comments
 (0)