
Commit a56ec8d

derdrdirk authored and Google-ML-Automation committed

[Autotuner] Only test XLAComputation in sharded autotuner test.

PiperOrigin-RevId: 853179571

1 parent 97a3a5a commit a56ec8d

File tree

1 file changed: +20 −42 lines changed

xla/pjrt/gpu/se_gpu_pjrt_client_test.cc

Lines changed: 20 additions & 42 deletions
@@ -3233,15 +3233,13 @@ absl::Status SuccessfulCrossHostTransferTestBody(bool is_sender,
 }
 
 struct ShardedAutotuningTestInfo {
-  bool use_xla_computation;
   int num_active_nodes;
   int num_nodes_using_cache;
 
   static std::string Name(
       const ::testing::TestParamInfo<ShardedAutotuningTestInfo>& info) {
-    return absl::StrFormat(
-        "computation_%d_active_%d_cache_%d", info.param.use_xla_computation,
-        info.param.num_active_nodes, info.param.num_nodes_using_cache);
+    return absl::StrFormat("active_%d_cache_%d", info.param.num_active_nodes,
+                           info.param.num_nodes_using_cache);
   }
 };
 
@@ -3273,8 +3271,6 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
   argv.push_back("sharded_autotuning_test");
   argv.push_back("--test_to_run=ShardedAutotuningWorksHelper");
   argv.push_back(absl::StrFormat("--node_id=%d", node_id));
-  argv.push_back(absl::StrFormat("--use_xla_computation=%d",
-                                 param.use_xla_computation));
   argv.push_back(
       absl::StrFormat("--num_active_nodes=%d", param.num_active_nodes));
   argv.push_back(absl::StrFormat("--num_nodes_using_cache=%d",
@@ -3324,8 +3320,7 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
 absl::Status ShardedAutotuningWorksTestBody(const int node_id,
                                             const int num_active_nodes,
                                             const int num_nodes_using_cache,
-                                            absl::string_view cache_dir,
-                                            bool use_xla_computation) {
+                                            absl::string_view cache_dir) {
   std::unique_ptr<xla::DistributedRuntimeService> service;
   if (node_id == 0) {
     TF_ASSIGN_OR_RETURN(
@@ -3377,31 +3372,21 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
     debug_options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir);
   }
 
-  mlir::MLIRContext context;
-  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
-                      ParseMlirModuleString(R"mlir(
-      func.func public @main(%arg0: tensor<2x32x32xf16>) ->
-          (tensor<2x32x32xf16> {jax.result_info = ""}) {
-        %0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0],
-            contracting_dims = [2] x [1]
-          : (tensor<2x32x32xf16>, tensor<2x32x32xf16>) ->
-            tensor<2x32x32xf16>
-        return %0 : tensor<2x32x32xf16>
-      })mlir",
-                                            context));
+  const char* kHlo = R"(
+    HloModule main
+    ENTRY main {
+      %p0 = f16[2,32,32] parameter(0)
+      ROOT %dot = f16[2,32,32] dot(%p0, %p0), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    }
+  )";
+
+  TF_ASSIGN_OR_RETURN(auto hlo_module,
+                      ParseAndReturnUnverifiedModule(kHlo, {}));
+  xla::XlaComputation computation(hlo_module->ToProto());
+
   std::unique_ptr<PjRtLoadedExecutable> executable;
-  if (use_xla_computation) {
-    XlaComputation computation;
-    TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation,
-                                            /*use_tuple_args=*/false,
-                                            /*return_tuple=*/false,
-                                            /*exec_build_options=*/nullptr));
-    TF_ASSIGN_OR_RETURN(executable,
-                        client->CompileAndLoad(computation, compile_options));
-  } else {
-    TF_ASSIGN_OR_RETURN(executable,
-                        client->CompileAndLoad(*module, compile_options));
-  }
+  TF_ASSIGN_OR_RETURN(executable,
+                      client->CompileAndLoad(computation, compile_options));
 
   const std::string optimized_hlo =
       executable->GetExecutable()->GetHloModules()->front()->ToString();
@@ -3414,11 +3399,8 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
 
 INSTANTIATE_TEST_SUITE_P(
     ShardedAutotuningTest, ShardedAutotuningTest,
-    ::testing::ValuesIn(std::vector<ShardedAutotuningTestInfo>{{true, 2, 0},
-                                                               {false, 2, 0},
-                                                               {false, 1, 0},
-                                                               {false, 2, 1},
-                                                               {false, 2, 2}}),
+    ::testing::ValuesIn(std::vector<ShardedAutotuningTestInfo>{
+        {2, 0}, {1, 0}, {2, 1}, {2, 2}}),
     ShardedAutotuningTestInfo::Name);
 
 }  // namespace
@@ -3437,7 +3419,6 @@ int main(int argc, char* argv[]) {
   int num_active_nodes = -1;
   int num_nodes_using_cache = -1;
   std::string cache_dir;
-  bool use_xla_computation = false;
 
   // Variables used by SuccessfulCrossHostTransfer.
   std::string cross_host_test_role;
@@ -3458,8 +3439,6 @@ int main(int argc, char* argv[]) {
                 "Test parameter for ShardedAutotuningWorks."),
       tsl::Flag("cache_dir", &cache_dir,
                 "Test parameter for ShardedAutotuningWorks."),
-      tsl::Flag("use_xla_computation", &use_xla_computation,
-                "Test parameter for ShardedAutotuningWorks."),
 
       // Flags for SuccessfulCrossHostTransfer.
       tsl::Flag("cross_host_test_role", &cross_host_test_role,
@@ -3480,8 +3459,7 @@ int main(int argc, char* argv[]) {
 
   if (test_to_run == "ShardedAutotuningWorksHelper") {
     absl::Status result = xla::ShardedAutotuningWorksTestBody(
-        node_id, num_active_nodes, num_nodes_using_cache, cache_dir,
-        use_xla_computation);
+        node_id, num_active_nodes, num_nodes_using_cache, cache_dir);
     if (!result.ok()) {
       LOG(ERROR) << result;
     }
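
For reference, a minimal standalone sketch of the compile path the test now takes: textual HLO is parsed into an HloModule and wrapped in an XlaComputation, which CompileAndLoad accepts directly. The include paths and the BuildDotComputation helper below are illustrative assumptions, not part of the commit; paths have moved between versions of the XLA tree.

    // Minimal sketch, assuming these XLA include paths; BuildDotComputation
    // is a hypothetical helper, not code from the commit.
    #include <memory>

    #include "absl/status/statusor.h"
    #include "tsl/platform/statusor.h"
    #include "xla/hlo/builder/xla_computation.h"
    #include "xla/hlo/ir/hlo_module.h"
    #include "xla/hlo/parser/hlo_parser.h"

    namespace example {

    // HLO text -> HloModule -> XlaComputation (a thin wrapper around the
    // module's proto), mirroring what the test body now does.
    absl::StatusOr<xla::XlaComputation> BuildDotComputation() {
      constexpr char kHlo[] = R"(
        HloModule main
        ENTRY main {
          %p0 = f16[2,32,32] parameter(0)
          ROOT %dot = f16[2,32,32] dot(%p0, %p0), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
        }
      )";
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> hlo_module,
                          xla::ParseAndReturnUnverifiedModule(kHlo, {}));
      return xla::XlaComputation(hlo_module->ToProto());
    }

    }  // namespace example

With the use_xla_computation parameter gone, only the XlaComputation overload of CompileAndLoad is covered here; the MLIR-module overload is no longer exercised by this test.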

0 commit comments
