@@ -3233,15 +3233,13 @@ absl::Status SuccessfulCrossHostTransferTestBody(bool is_sender,
32333233}
32343234
32353235struct ShardedAutotuningTestInfo {
3236- bool use_xla_computation;
32373236 int num_active_nodes;
32383237 int num_nodes_using_cache;
32393238
32403239 static std::string Name (
32413240 const ::testing::TestParamInfo<ShardedAutotuningTestInfo>& info) {
3242- return absl::StrFormat (
3243- " computation_%d_active_%d_cache_%d" , info.param .use_xla_computation ,
3244- info.param .num_active_nodes , info.param .num_nodes_using_cache );
3241+ return absl::StrFormat (" active_%d_cache_%d" , info.param .num_active_nodes ,
3242+ info.param .num_nodes_using_cache );
32453243 }
32463244};
32473245
@@ -3273,8 +3271,6 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
32733271 argv.push_back (" sharded_autotuning_test" );
32743272 argv.push_back (" --test_to_run=ShardedAutotuningWorksHelper" );
32753273 argv.push_back (absl::StrFormat (" --node_id=%d" , node_id));
3276- argv.push_back (absl::StrFormat (" --use_xla_computation=%d" ,
3277- param.use_xla_computation ));
32783274 argv.push_back (
32793275 absl::StrFormat (" --num_active_nodes=%d" , param.num_active_nodes ));
32803276 argv.push_back (absl::StrFormat (" --num_nodes_using_cache=%d" ,
@@ -3324,8 +3320,7 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
33243320absl::Status ShardedAutotuningWorksTestBody (const int node_id,
33253321 const int num_active_nodes,
33263322 const int num_nodes_using_cache,
3327- absl::string_view cache_dir,
3328- bool use_xla_computation) {
3323+ absl::string_view cache_dir) {
33293324 std::unique_ptr<xla::DistributedRuntimeService> service;
33303325 if (node_id == 0 ) {
33313326 TF_ASSIGN_OR_RETURN (
@@ -3377,31 +3372,21 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
33773372 debug_options.set_xla_gpu_per_fusion_autotune_cache_dir (cache_dir);
33783373 }
33793374
3380- mlir::MLIRContext context;
3381- TF_ASSIGN_OR_RETURN (mlir::OwningOpRef<mlir::ModuleOp> module ,
3382- ParseMlirModuleString ( R"mlir(
3383- func.func public @main(%arg0: tensor<2x32x32xf16>) ->
3384- (tensor<2x32x32xf16> {jax.result_info = ""}) {
3385- %0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0],
3386- contracting_dims = [2] x [1]
3387- : (tensor<2x32x32xf16>, tensor<2x32x32xf16>) ->
3388- tensor<2x32x32xf16>
3389- return %0 : tensor<2x32x32xf16>
3390- } )mlir" ,
3391- context));
3375+ const char * kHlo = R"(
3376+ HloModule main
3377+ ENTRY main {
3378+ %p0 = f16[2,32,32] parameter(0)
3379+ ROOT %dot = f16[2,32,32] dot(%p0, %p0), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
3380+ }
3381+ )" ;
3382+
3383+ TF_ASSIGN_OR_RETURN ( auto hlo_module,
3384+ ParseAndReturnUnverifiedModule ( kHlo , {}));
3385+ xla::XlaComputation computation (hlo_module-> ToProto ());
3386+
33923387 std::unique_ptr<PjRtLoadedExecutable> executable;
3393- if (use_xla_computation) {
3394- XlaComputation computation;
3395- TF_RETURN_IF_ERROR (MlirToXlaComputation (*module , computation,
3396- /* use_tuple_args=*/ false ,
3397- /* return_tuple=*/ false ,
3398- /* exec_build_options=*/ nullptr ));
3399- TF_ASSIGN_OR_RETURN (executable,
3400- client->CompileAndLoad (computation, compile_options));
3401- } else {
3402- TF_ASSIGN_OR_RETURN (executable,
3403- client->CompileAndLoad (*module , compile_options));
3404- }
3388+ TF_ASSIGN_OR_RETURN (executable,
3389+ client->CompileAndLoad (computation, compile_options));
34053390
34063391 const std::string optimized_hlo =
34073392 executable->GetExecutable ()->GetHloModules ()->front ()->ToString ();
@@ -3414,11 +3399,8 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
34143399
34153400INSTANTIATE_TEST_SUITE_P (
34163401 ShardedAutotuningTest, ShardedAutotuningTest,
3417- ::testing::ValuesIn (std::vector<ShardedAutotuningTestInfo>{{true , 2 , 0 },
3418- {false , 2 , 0 },
3419- {false , 1 , 0 },
3420- {false , 2 , 1 },
3421- {false , 2 , 2 }}),
3402+ ::testing::ValuesIn (std::vector<ShardedAutotuningTestInfo>{
3403+ {2 , 0 }, {1 , 0 }, {2 , 1 }, {2 , 2 }}),
34223404 ShardedAutotuningTestInfo::Name);
34233405
34243406} // namespace
@@ -3437,7 +3419,6 @@ int main(int argc, char* argv[]) {
34373419 int num_active_nodes = -1 ;
34383420 int num_nodes_using_cache = -1 ;
34393421 std::string cache_dir;
3440- bool use_xla_computation = false ;
34413422
34423423 // Variables used by SuccessfulCrossHostTransfer.
34433424 std::string cross_host_test_role;
@@ -3458,8 +3439,6 @@ int main(int argc, char* argv[]) {
34583439 " Test parameter for ShardedAutotuningWorks." ),
34593440 tsl::Flag (" cache_dir" , &cache_dir,
34603441 " Test parameter for ShardedAutotuningWorks." ),
3461- tsl::Flag (" use_xla_computation" , &use_xla_computation,
3462- " Test parameter for ShardedAutotuningWorks." ),
34633442
34643443 // Flags for SuccessfulCrossHostTransfer.
34653444 tsl::Flag (" cross_host_test_role" , &cross_host_test_role,
@@ -3480,8 +3459,7 @@ int main(int argc, char* argv[]) {
34803459
34813460 if (test_to_run == " ShardedAutotuningWorksHelper" ) {
34823461 absl::Status result = xla::ShardedAutotuningWorksTestBody (
3483- node_id, num_active_nodes, num_nodes_using_cache, cache_dir,
3484- use_xla_computation);
3462+ node_id, num_active_nodes, num_nodes_using_cache, cache_dir);
34853463 if (!result.ok ()) {
34863464 LOG (ERROR) << result;
34873465 }
0 commit comments