Skip to content

Commit 5b41cc2

Browse files
nvgrwGoogle-ML-Automation
authored andcommitted
Migrate sample_file_test to HloRunnerPjRt.
PiperOrigin-RevId: 820803579
1 parent 77643f3 commit 5b41cc2

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

xla/tests/BUILD

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3384,14 +3384,26 @@ xla_test(
33843384
srcs = ["sample_file_test.cc"],
33853385
backends = ["gpu"],
33863386
data = ["isolated_convolution.hlo"],
3387+
tags = [
3388+
"test_migrated_to_hlo_runner_pjrt",
3389+
],
33873390
deps = [
3388-
":hlo_test_base",
3391+
":hlo_pjrt_test_base",
3392+
":hlo_runner_agnostic_reference_mixin",
33893393
":xla_internal_test_main", # fixdeps: keep
3394+
"//xla:error_spec",
3395+
"//xla/hlo/ir:hlo",
33903396
"//xla/hlo/testlib:test",
3391-
"//xla/service:cpu_plugin", # reference backend
3392-
"//xla/service:platform_util",
3397+
"//xla/pjrt:pjrt_client",
3398+
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
3399+
"//xla/service:hlo_module_util",
3400+
"//xla/service:hlo_runner_interface",
3401+
"//xla/service:hlo_runner_pjrt",
3402+
"//xla/tsl/platform:statusor",
3403+
"//xla/tsl/platform:test",
3404+
"@com_google_absl//absl/log",
3405+
"@com_google_absl//absl/status:statusor",
33933406
"@tsl//tsl/platform:path",
3394-
"@tsl//tsl/platform:test",
33953407
],
33963408
)
33973409

xla/tests/sample_file_test.cc

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,57 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
// This demonstrates how to use hlo_test_base to create a file based testcase
17-
// and compare results on gpu and cpu.
16+
// This demonstrates how to create a file-based test case and compare results
17+
// between gpu and cpu.
1818

19+
#include <memory>
1920
#include <string>
20-
#include <vector>
21+
#include <utility>
2122

23+
#include "absl/log/log.h"
24+
#include "absl/status/statusor.h"
25+
#include "xla/error_spec.h"
26+
#include "xla/hlo/ir/hlo_module.h"
2227
#include "xla/hlo/testlib/test.h"
23-
#include "xla/service/platform_util.h"
24-
#include "xla/tests/hlo_test_base.h"
28+
#include "xla/pjrt/pjrt_client.h"
29+
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
30+
#include "xla/service/hlo_module_util.h"
31+
#include "xla/service/hlo_runner_interface.h"
32+
#include "xla/service/hlo_runner_pjrt.h"
33+
#include "xla/tests/hlo_pjrt_test_base.h"
34+
#include "xla/tests/hlo_runner_agnostic_reference_mixin.h"
35+
#include "xla/tsl/platform/statusor.h"
36+
#include "xla/tsl/platform/test.h"
2537
#include "tsl/platform/path.h"
26-
#include "tsl/platform/test.h"
2738

2839
namespace xla {
2940
namespace {
3041

31-
class SampleFileTest : public HloTestBase {
42+
std::unique_ptr<HloRunnerInterface> GetReferenceRunner() {
43+
absl::StatusOr<std::unique_ptr<PjRtClient>> client = GetXlaPjrtCpuClient({});
44+
if (!client.ok()) {
45+
LOG(FATAL) << "Failed to create XLA:CPU PjRtClient: " << client.status();
46+
}
47+
return std::make_unique<HloRunnerPjRt>(*std::move(client));
48+
}
49+
50+
class SampleFileTest : public HloRunnerAgnosticReferenceMixin<HloPjRtTestBase> {
3251
protected:
3352
SampleFileTest()
34-
: HloTestBase(
35-
/*test_platform=*/PlatformUtil::GetPlatform("gpu").value(),
36-
/*reference_platform=*/PlatformUtil::GetPlatform("cpu").value()) {}
53+
: HloRunnerAgnosticReferenceMixin<HloPjRtTestBase>(
54+
/*reference_runner=*/GetReferenceRunner()) {}
3755
};
3856

3957
TEST_F(SampleFileTest, Convolution) {
40-
const std::string& filename = tsl::io::JoinPath(
58+
const std::string filename = tsl::io::JoinPath(
4159
tsl::testing::XlaSrcRoot(), "tests", "isolated_convolution.hlo");
42-
EXPECT_TRUE(RunAndCompareFromFile(filename, ErrorSpec{0.01}));
60+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
61+
ReadModuleFromHloTextFile(filename));
62+
module->mutable_config()
63+
.mutable_debug_options()
64+
.set_xla_cpu_parallel_codegen_split_count(1);
65+
66+
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}));
4367
}
4468

4569
} // namespace

0 commit comments

Comments
 (0)