Skip to content

Commit f9f4b06

Browse files
BabakkGraphcorepbchekin
authored andcommitted
Use llvm filecheck instead of filecheck python lib. (#7070)
The filecheck python dependency isn't available in python3.9. This change removes it in lieu of the llvm FileCheck binary, which is now packaged into the wheel.
1 parent 9f3e3fe commit f9f4b06

File tree

5 files changed

+40
-20
lines changed

5 files changed

+40
-20
lines changed

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,20 @@ if(TRITON_BUILD_PYTHON_MODULE)
321321
target_link_libraries(triton PRIVATE z)
322322
endif()
323323
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
324+
325+
if (NOT DEFINED LLVM_SYSPATH)
326+
message(FATAL_ERROR "LLVM_SYSPATH must be set.")
327+
endif()
328+
329+
if (NOT DEFINED TRITON_WHEEL_DIR)
330+
message(FATAL_ERROR "TRITON_WHEEL_DIR must be set.")
331+
endif()
332+
333+
configure_file(
334+
"${LLVM_SYSPATH}/bin/FileCheck"
335+
"${TRITON_WHEEL_DIR}/FileCheck"
336+
COPYONLY)
337+
324338
endif()
325339

326340
if (UNIX AND NOT APPLE)

python/test-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@ pytest-forked
66
pytest-xdist
77
scipy>=1.7.1
88
llnl-hatchet
9-
filecheck
109
expecttest

python/test/unit/test_filecheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ def test_kernel():
3232
# CHECK: %c42_i32
3333
anchor(scalar)
3434

35-
with pytest.raises(ValueError, match="Couldn't match \"%c42_i32\""):
35+
with pytest.raises(ValueError, match="expected string not found in input\n # CHECK: %c42_i32"):
3636
run_filecheck_test(test_kernel)

python/triton/_filecheck.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
import sys
21
import os
3-
import io
42
import inspect
5-
6-
from filecheck.options import Options
7-
from filecheck.finput import FInput
8-
from filecheck.parser import Parser, pattern_for_opts
9-
from filecheck.matcher import Matcher
3+
import subprocess
4+
import tempfile
105

116
import triton
127
from triton.compiler import ASTSource, make_backend
@@ -22,8 +17,8 @@
2217
stub_target = GPUTarget("cuda", 100, 32)
2318
stub_backend = make_backend(stub_target)
2419

25-
llvm_bin_dir = os.path.join(os.path.dirname(sys.executable), "bin")
26-
filecheck_path = os.path.join(llvm_bin_dir, "FileCheck")
20+
triton_dir = os.path.dirname(__file__)
21+
filecheck_path = os.path.join(triton_dir, "FileCheck")
2722

2823

2924
class MatchError(ValueError):
@@ -37,14 +32,21 @@ def __str__(self):
3732

3833

3934
def run_filecheck(name, module_str, check_template):
40-
options = Options(match_filename=name)
41-
fin = FInput(name, module_str)
42-
ops = io.StringIO(check_template)
43-
parser = Parser(options, ops, *pattern_for_opts(options))
44-
matcher = Matcher(options, fin, parser)
45-
matcher.stderr = io.StringIO()
46-
if matcher.run() != 0:
47-
raise MatchError(matcher.stderr.getvalue(), module_str)
35+
with tempfile.TemporaryDirectory() as tempdir:
36+
temp_module = os.path.join(tempdir, "module")
37+
with open(temp_module, "w") as temp:
38+
temp.write(module_str)
39+
40+
temp_expected = os.path.join(tempdir, "expected")
41+
with open(temp_expected, "w") as temp:
42+
temp.write(check_template)
43+
44+
try:
45+
subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module],
46+
stderr=subprocess.STDOUT)
47+
except subprocess.CalledProcessError as error:
48+
decoded = error.output.decode('unicode_escape')
49+
raise ValueError(decoded)
4850

4951

5052
def run_parser(kernel_fn):

setup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@ def get_thirdparty_packages(packages: list):
317317
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
318318
if p.lib_flag:
319319
thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
320+
if p.syspath_var_name:
321+
thirdparty_cmake_args.append(f"-D{p.syspath_var_name}={package_dir}")
320322
if p.sym_name is not None:
321323
sym_link_path = os.path.join(package_root_dir, p.sym_name)
322324
update_symlink(sym_link_path, package_dir)
@@ -453,6 +455,8 @@ def build_extension(self, ext):
453455
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
454456
thirdparty_cmake_args += self.get_pybind11_cmake_args()
455457
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
458+
wheeldir = os.path.dirname(extdir)
459+
456460
# create build directories
457461
if not os.path.exists(self.build_temp):
458462
os.makedirs(self.build_temp)
@@ -466,7 +470,8 @@ def build_extension(self, ext):
466470
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_PYTHON_MODULE=ON",
467471
"-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, "-DPython3_INCLUDE_DIR=" + python_include_dir,
468472
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
469-
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
473+
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]),
474+
"-DTRITON_WHEEL_DIR=" + wheeldir
470475
]
471476
if lit_dir is not None:
472477
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)

0 commit comments

Comments
 (0)