Skip to content

Commit 180404a

Browse files
gnskxxt
authored andcommitted
[XLA:CPU] Add support for riscv64
Co-authored-by: Levi Zim <[email protected]>
1 parent 5525a3f commit 180404a

File tree

16 files changed

+205
-3
lines changed

16 files changed

+205
-3
lines changed

third_party/xla/third_party/hwloc/hwloc.BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ cc_library(
272272
"hwloc/topology-linux.c",
273273
"include/hwloc/linux.h",
274274
],
275+
"@local_xla//xla/tsl:linux_riscv64": [
276+
"hwloc/topology-linux.c",
277+
"include/hwloc/linux.h",
278+
],
275279
"@local_xla//xla/tsl:linux_s390x": [
276280
"hwloc/topology-linux.c",
277281
"include/hwloc/linux.h",

third_party/xla/third_party/llvm/toolchains.patch

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ diff --git a/utils/bazel/llvm-project-overlay/llvm/config.bzl b/utils/bazel/llvm
4444
index 2e3bff53ead9..8d01617effdc 100644
4545
--- a/utils/bazel/llvm-project-overlay/llvm/config.bzl
4646
+++ b/utils/bazel/llvm-project-overlay/llvm/config.bzl
47-
@@ -98,8 +98,9 @@ builtin_thread_pointer = select({
47+
@@ -98,8 +98,10 @@ builtin_thread_pointer = select({
4848
# TODO: We should split out host vs. target here.
4949
llvm_config_defines = os_defines + builtin_thread_pointer + select({
5050
"@bazel_tools//src/conditions:windows": native_arch_defines("X86", "x86_64-pc-win32"),
@@ -53,6 +53,7 @@ index 2e3bff53ead9..8d01617effdc 100644
5353
+ "//llvm:macos_arm64": native_arch_defines("AArch64", "arm64-apple-darwin"),
5454
+ "//llvm:macos_x86_64": native_arch_defines("X86", "x86_64-unknown-darwin"),
5555
+ "//llvm:macos_x86_64_default": native_arch_defines("X86", "x86_64-unknown-darwin"),
56+
+ "@bazel_tools//src/conditions:linux_riscv64": native_arch_defines("RISCV", "riscv64-unknown-linux-gnu"),
5657
"@bazel_tools//src/conditions:linux_aarch64": native_arch_defines("AArch64", "aarch64-unknown-linux-gnu"),
5758
"@bazel_tools//src/conditions:linux_ppc64le": native_arch_defines("PowerPC", "powerpc64le-unknown-linux-gnu"),
5859
"@bazel_tools//src/conditions:linux_s390x": native_arch_defines("SystemZ", "systemz-unknown-linux_gnu"),

third_party/xla/third_party/mkl_dnn/mkldnn_v1.BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ cc_library(
220220
"@local_xla//xla/tsl:linux_aarch64": ["-lrt"],
221221
"@local_xla//xla/tsl:linux_x86_64": ["-lrt"],
222222
"@local_xla//xla/tsl:linux_ppc64le": ["-lrt"],
223+
"@local_xla//xla/tsl:linux_riscv64": ["-lrt"],
223224
"//conditions:default": [],
224225
}),
225226
textual_hdrs = _TEXTUAL_HDRS_LIST,

third_party/xla/third_party/py/manylinux_compliance_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def parse_args():
4646
required=True,
4747
help="ManyLinux compliance tag for ppc64le",
4848
)
49+
parser.add_argument(
50+
"--riscv64-compliance-tag",
51+
required=True,
52+
help="ManyLinux compliance tag for riscv64",
53+
)
4954
return parser.parse_args()
5055

5156

@@ -106,7 +111,7 @@ def verify_manylinux_compliance(
106111

107112
def test_manylinux_compliance(args):
108113
machine_type = platform.uname().machine
109-
supported_machine_types = ["x86_64", "aarch64", "ppc64le"]
114+
supported_machine_types = ["x86_64", "aarch64", "ppc64le", "riscv64"]
110115
if machine_type not in supported_machine_types:
111116
raise RuntimeError(
112117
"Unsupported machine type {machine_type}. The supported are:"
@@ -118,8 +123,10 @@ def test_manylinux_compliance(args):
118123
compliance_tag = args.x86_64_compliance_tag
119124
elif machine_type == "aarch64":
120125
compliance_tag = args.aarch64_compliance_tag
121-
else:
126+
elif machine_type == "ppc64le":
122127
compliance_tag = args.ppc64le_compliance_tag
128+
else: # machine_type == "riscv64"
129+
compliance_tag = args.riscv64_compliance_tag
123130
auditwheel_output = get_auditwheel_output(args.wheel_path)
124131
verify_manylinux_compliance(
125132
auditwheel_output,

third_party/xla/third_party/py/py_manylinux_compliance_test.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def verify_manylinux_compliance_test(
88
aarch64_compliance_tag,
99
x86_64_compliance_tag,
1010
ppc64le_compliance_tag,
11+
riscv64_compliance_tag,
1112
test_tags = []):
1213
py_test(
1314
name = name,
@@ -21,6 +22,7 @@ def verify_manylinux_compliance_test(
2122
"--aarch64-compliance-tag={}".format(aarch64_compliance_tag),
2223
"--x86_64-compliance-tag={}".format(x86_64_compliance_tag),
2324
"--ppc64le-compliance-tag={}".format(ppc64le_compliance_tag),
25+
"--riscv64-compliance-tag={}".format(riscv64_compliance_tag),
2426
],
2527
main = "manylinux_compliance_test.py",
2628
tags = ["manual"] + test_tags,

third_party/xla/third_party/py/python_init_rules.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ def python_init_rules(extra_patches = []):
4444
Label("//third_party/py:rules_python_pip_version.patch"),
4545
Label("//third_party/py:rules_python_freethreaded.patch"),
4646
Label("//third_party/py:rules_python_versions.patch"),
47+
Label("//third_party/py:rules_python_riscv64_pypi.patch"),
4748
] + extra_patches,
4849
)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
From: Levi Zim <[email protected]>
2+
Date: Tue, 14 Oct 2025 19:45:36 +0800
3+
Subject: [PATCH] fix: Add linux_riscv64 to _pip_repository_impl
4+
5+
Fix https://github.com/bazel-contrib/rules_python/discussions/2729
6+
---
7+
python/private/pypi/pip_repository.bzl | 1 +
8+
python/private/pypi/whl_installer/platform.py | 3 +++
9+
python/private/pypi/whl_target_platforms.bzl | 1 +
10+
tests/pypi/whl_installer/platform_test.py | 6 +++---
11+
.../whl_target_platforms/whl_target_platforms_tests.bzl | 9 +++++++++
12+
5 files changed, 17 insertions(+), 3 deletions(-)
13+
14+
diff --git a/python/private/pypi/pip_repository.bzl b/python/private/pypi/pip_repository.bzl
15+
index e9a4c44da3..d635651039 100644
16+
--- a/python/private/pypi/pip_repository.bzl
17+
+++ b/python/private/pypi/pip_repository.bzl
18+
@@ -96,6 +96,7 @@ def _pip_repository_impl(rctx):
19+
"linux_aarch64",
20+
"linux_arm",
21+
"linux_ppc",
22+
+ "linux_riscv64",
23+
"linux_s390x",
24+
"linux_x86_64",
25+
"osx_aarch64",
26+
diff --git a/python/private/pypi/whl_installer/platform.py b/python/private/pypi/whl_installer/platform.py
27+
index ff267fe4aa..0757d86990 100644
28+
--- a/python/private/pypi/whl_installer/platform.py
29+
+++ b/python/private/pypi/whl_installer/platform.py
30+
@@ -45,6 +45,7 @@ class Arch(Enum):
31+
ppc64le = 5
32+
s390x = 6
33+
arm = 7
34+
+ riscv64 = 8
35+
amd64 = x86_64
36+
arm64 = aarch64
37+
i386 = x86_32
38+
@@ -269,6 +270,8 @@ def platform_machine(self) -> str:
39+
return "ppc"
40+
elif self.arch == Arch.ppc64le:
41+
return "ppc64le"
42+
+ elif self.arch == Arch.riscv64:
43+
+ return "riscv64"
44+
elif self.arch == Arch.s390x:
45+
return "s390x"
46+
else:
47+
diff --git a/python/private/pypi/whl_target_platforms.bzl b/python/private/pypi/whl_target_platforms.bzl
48+
index 6c3dd5da83..28547c679c 100644
49+
--- a/python/private/pypi/whl_target_platforms.bzl
50+
+++ b/python/private/pypi/whl_target_platforms.bzl
51+
@@ -30,6 +30,7 @@ _CPU_ALIASES = {
52+
"ppc": "ppc",
53+
"ppc64": "ppc",
54+
"ppc64le": "ppc64le",
55+
+ "riscv64": "riscv64",
56+
"s390x": "s390x",
57+
"arm": "arm",
58+
"armv6l": "arm",
59+
diff --git a/tests/pypi/whl_installer/platform_test.py b/tests/pypi/whl_installer/platform_test.py
60+
index ad65650779..0d944bb196 100644
61+
--- a/tests/pypi/whl_installer/platform_test.py
62+
+++ b/tests/pypi/whl_installer/platform_test.py
63+
@@ -38,17 +38,17 @@ def test_can_get_specific_from_string(self):
64+
65+
def test_can_get_all_for_py_version(self):
66+
cp39 = Platform.all(minor_version=9, micro_version=0)
67+
- self.assertEqual(21, len(cp39), f"Got {cp39}")
68+
+ self.assertEqual(24, len(cp39), f"Got {cp39}")
69+
self.assertEqual(cp39, Platform.from_string("cp39.0_*"))
70+
71+
def test_can_get_all_for_os(self):
72+
linuxes = Platform.all(OS.linux, minor_version=9)
73+
- self.assertEqual(7, len(linuxes))
74+
+ self.assertEqual(8, len(linuxes))
75+
self.assertEqual(linuxes, Platform.from_string("cp39_linux_*"))
76+
77+
def test_can_get_all_for_os_for_host_python(self):
78+
linuxes = Platform.all(OS.linux)
79+
- self.assertEqual(7, len(linuxes))
80+
+ self.assertEqual(8, len(linuxes))
81+
self.assertEqual(linuxes, Platform.from_string("linux_*"))
82+
83+
def test_platform_sort(self):
84+
diff --git a/tests/pypi/whl_target_platforms/whl_target_platforms_tests.bzl b/tests/pypi/whl_target_platforms/whl_target_platforms_tests.bzl
85+
index a976a0cf95..8b7f0ad02b 100644
86+
--- a/tests/pypi/whl_target_platforms/whl_target_platforms_tests.bzl
87+
+++ b/tests/pypi/whl_target_platforms/whl_target_platforms_tests.bzl
88+
@@ -34,6 +34,9 @@ def _test_simple(env):
89+
"musllinux_1_1_ppc64le": [
90+
struct(os = "linux", cpu = "ppc64le", abi = None, target_platform = "linux_ppc64le", version = (1, 1)),
91+
],
92+
+ "musllinux_1_2_riscv64": [
93+
+ struct(os = "linux", cpu = "riscv64", abi = None, target_platform = "linux_riscv64", version = (1, 2)),
94+
+ ],
95+
"win_amd64": [
96+
struct(os = "windows", cpu = "x86_64", abi = None, target_platform = "windows_x86_64", version = (0, 0)),
97+
],
98+
@@ -66,6 +69,9 @@ def _test_with_abi(env):
99+
"musllinux_1_1_ppc64le": [
100+
struct(os = "linux", cpu = "ppc64le", abi = "cp311", target_platform = "cp311_linux_ppc64le", version = (1, 1)),
101+
],
102+
+ "musllinux_1_2_riscv64": [
103+
+ struct(os = "linux", cpu = "riscv64", abi = "cp311", target_platform = "cp311_linux_riscv64", version = (1, 2)),
104+
+ ],
105+
"win_amd64": [
106+
struct(os = "windows", cpu = "x86_64", abi = "cp311", target_platform = "cp311_windows_x86_64", version = (0, 0)),
107+
],
108+
@@ -96,6 +102,7 @@ def _can_parse_existing_tags(env):
109+
"manylinux2014_i686": 1,
110+
"manylinux2014_ppc64": 1,
111+
"manylinux2014_ppc64le": 1,
112+
+ "manylinux2014_riscv64": 1,
113+
"manylinux2014_s390x": 1,
114+
"manylinux2014_x86_64": 1,
115+
"manylinux_11_12_aarch64": 1,
116+
@@ -103,6 +110,7 @@ def _can_parse_existing_tags(env):
117+
"manylinux_11_12_i686": 1,
118+
"manylinux_11_12_ppc64": 1,
119+
"manylinux_11_12_ppc64le": 1,
120+
+ "manylinux_11_12_riscv64": 1,
121+
"manylinux_11_12_s390x": 1,
122+
"manylinux_11_12_x86_64": 1,
123+
"manylinux_1_2_aarch64": 1,
124+
@@ -111,6 +119,7 @@ def _can_parse_existing_tags(env):
125+
"musllinux_11_12_armv7l": 1,
126+
"musllinux_11_12_i686": 1,
127+
"musllinux_11_12_ppc64le": 1,
128+
+ "musllinux_11_12_riscv64": 1,
129+
"musllinux_11_12_s390x": 1,
130+
"musllinux_11_12_x86_64": 1,
131+
"win32": 1,
132+
133+
From f03ff72726015e64aa3e9af6adf34cb651a92c88 Mon Sep 17 00:00:00 2001
134+
From: Levi Zim <[email protected]>
135+
Date: Tue, 14 Oct 2025 20:26:34 +0800
136+
Subject: [PATCH 2/2] Empty commit to retry CI
137+
138+
CI fails with connection timed out

third_party/xla/xla/backends/cpu/codegen/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ load(
55
"//xla/tsl/platform:build_config_root.bzl",
66
"if_llvm_aarch64_available",
77
"if_llvm_powerpc_available",
8+
"if_llvm_riscv_available",
89
"if_llvm_system_z_available",
910
"if_llvm_x86_available",
1011
)
@@ -129,6 +130,8 @@ xla_cc_test(
129130
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
130131
]) + if_llvm_powerpc_available([
131132
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
133+
]) + if_llvm_riscv_available([
134+
"@llvm-project//llvm:RISCVCodeGen", # fixdeps: keep
132135
]) + if_llvm_system_z_available([
133136
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
134137
]) + if_llvm_x86_available([
@@ -249,6 +252,8 @@ cc_library(
249252
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
250253
]) + if_llvm_powerpc_available([
251254
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
255+
]) + if_llvm_riscv_available([
256+
"@llvm-project//llvm:RISCVCodeGen", # fixdeps: keep
252257
]) + if_llvm_system_z_available([
253258
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
254259
]) + if_llvm_x86_available([

third_party/xla/xla/backends/cpu/codegen/tools/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ load(
44
"//xla/tsl/platform:build_config_root.bzl",
55
"if_llvm_aarch64_available",
66
"if_llvm_powerpc_available",
7+
"if_llvm_riscv_available",
78
"if_llvm_system_z_available",
89
"if_llvm_x86_available",
910
)
@@ -38,6 +39,8 @@ cc_library(
3839
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
3940
]) + if_llvm_powerpc_available([
4041
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
42+
]) + if_llvm_riscv_available([
43+
"@llvm-project//llvm:RISCVCodeGen", # fixdeps: keep
4144
]) + if_llvm_system_z_available([
4245
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
4346
]) + if_llvm_x86_available([

third_party/xla/xla/codegen/intrinsic/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ load(
33
"//xla/tsl/platform:build_config_root.bzl",
44
"if_llvm_aarch64_available",
55
"if_llvm_powerpc_available",
6+
"if_llvm_riscv_available",
67
"if_llvm_system_z_available",
78
"if_llvm_x86_available",
89
)
@@ -157,6 +158,9 @@ cc_library(
157158
]) + if_llvm_powerpc_available([
158159
"@llvm-project//llvm:PowerPCAsmParser", # fixdeps: keep
159160
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
161+
]) + if_llvm_riscv_available([
162+
"@llvm-project//llvm:RISCVAsmParser", # fixdeps: keep
163+
"@llvm-project//llvm:RISCVCodeGen", # fixdeps: keep
160164
]) + if_llvm_system_z_available([
161165
"@llvm-project//llvm:SystemZAsmParser", # fixdeps: keep
162166
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep

0 commit comments

Comments
 (0)