@@ -13,33 +13,57 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- // This demonstrates how to use hlo_test_base to create a file based testcase
17- // and compare results on gpu and cpu.
16+ // This demonstrates how to create a file- based test case and compare results
17+ // between gpu and cpu.
1818
19+ #include < memory>
1920#include < string>
20- #include < vector >
21+ #include < utility >
2122
23+ #include " absl/log/log.h"
24+ #include " absl/status/statusor.h"
25+ #include " xla/error_spec.h"
26+ #include " xla/hlo/ir/hlo_module.h"
2227#include " xla/hlo/testlib/test.h"
23- #include " xla/service/platform_util.h"
24- #include " xla/tests/hlo_test_base.h"
28+ #include " xla/pjrt/pjrt_client.h"
29+ #include " xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
30+ #include " xla/service/hlo_module_util.h"
31+ #include " xla/service/hlo_runner_interface.h"
32+ #include " xla/service/hlo_runner_pjrt.h"
33+ #include " xla/tests/hlo_pjrt_test_base.h"
34+ #include " xla/tests/hlo_runner_agnostic_reference_mixin.h"
35+ #include " xla/tsl/platform/statusor.h"
36+ #include " xla/tsl/platform/test.h"
2537#include " tsl/platform/path.h"
26- #include " tsl/platform/test.h"
2738
2839namespace xla {
2940namespace {
3041
31- class SampleFileTest : public HloTestBase {
42+ std::unique_ptr<HloRunnerInterface> GetReferenceRunner () {
43+ absl::StatusOr<std::unique_ptr<PjRtClient>> client = GetXlaPjrtCpuClient ({});
44+ if (!client.ok ()) {
45+ LOG (FATAL) << " Failed to create XLA:CPU PjRtClient: " << client.status ();
46+ }
47+ return std::make_unique<HloRunnerPjRt>(*std::move (client));
48+ }
49+
50+ class SampleFileTest : public HloRunnerAgnosticReferenceMixin <HloPjRtTestBase> {
3251 protected:
3352 SampleFileTest ()
34- : HloTestBase(
35- /* test_platform=*/ PlatformUtil::GetPlatform(" gpu" ).value(),
36- /* reference_platform=*/ PlatformUtil::GetPlatform(" cpu" ).value()) {}
53+ : HloRunnerAgnosticReferenceMixin<HloPjRtTestBase>(
54+ /* reference_runner=*/ GetReferenceRunner()) {}
3755};
3856
3957TEST_F (SampleFileTest, Convolution) {
40- const std::string& filename = tsl::io::JoinPath (
58+ const std::string filename = tsl::io::JoinPath (
4159 tsl::testing::XlaSrcRoot (), " tests" , " isolated_convolution.hlo" );
42- EXPECT_TRUE (RunAndCompareFromFile (filename, ErrorSpec{0.01 }));
60+ TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
61+ ReadModuleFromHloTextFile (filename));
62+ module ->mutable_config ()
63+ .mutable_debug_options ()
64+ .set_xla_cpu_parallel_codegen_split_count (1 );
65+
66+ EXPECT_TRUE (RunAndCompare (std::move (module ), ErrorSpec{0.01 }));
4367}
4468
4569} // namespace
0 commit comments