Skip to content

Commit f3892f0

Browse files
Fix issue #5544: the double GRF mode is not used when building a native binary. (#5560) (#5576)
Get the spill size from the zebin instead of parsing it from ocloc's output string. This requires an extra Python package, `pyelftools`. --------- (cherry picked from commit ba1d008) Signed-off-by: Lu,Chengjun <[email protected]> Co-authored-by: Lu, Chengjun <[email protected]>
1 parent be39540 commit f3892f0

File tree

3 files changed

+46
-12
lines changed

3 files changed

+46
-12
lines changed

python/test/unit/intel/test_native_code_generation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,18 @@ def kernel(X, SIZE: tl.constexpr):
1313

1414
x = to_triton(numpy_random(SIZE, dtype_str="bfloat16"), device=device, dst_type="bfloat16")
1515
kernel[(1, )](x, SIZE=SIZE, num_warps=4, generate_native_code=True)
16+
17+
18+
def test_auto_large_grf(device):
    """Launch a sort kernel with grf_mode='default' and expect the 256-GRF build flag."""
    SIZE = 1024

    @triton.jit
    def kernel(X, SIZE: tl.constexpr):
        # Sort the index range in descending order and store it at X + index.
        offsets = tl.arange(0, SIZE)
        sorted_desc = tl.sort(offsets, descending=True)
        tl.store(X + offsets, sorted_desc)

    host_data = numpy_random(SIZE, dtype_str="float32")
    device_buf = to_triton(host_data, device=device, dst_type="float32")

    # Triton XPU will auto choose large GRF mode for grf_mode='default'
    launch_kwargs = dict(SIZE=SIZE, num_warps=1, generate_native_code=True, grf_mode='default')
    compiled = kernel[(1, )](device_buf, **launch_kwargs)

    build_flags = compiled.metadata.build_flags
    assert "-cl-intel-256-GRF-per-thread" in build_flags

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ def get_triton_version_suffix():
856856
description="A language and compiler for custom Deep Learning operations",
857857
long_description="",
858858
install_requires=[
859+
"pyelftools",
859860
"importlib-metadata; python_version < '3.10'",
860861
],
861862
packages=list(get_packages()),

third_party/intel/backend/compiler.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
import hashlib
1212
import tempfile
1313
import signal
14+
import re
1415
import os
1516
import subprocess
1617
from pathlib import Path
18+
from elftools.elf.elffile import ELFFile
1719

1820
try: # XPUBackend allows metaclasses injection
1921
from .meta import XPUBackendMeta
@@ -68,6 +70,23 @@ def hash(self):
6870
return hashlib.sha256(key.encode("utf-8")).hexdigest()
6971

7072

73+
# Matches a `spill_size: <N>` or `spill_size = <N>` entry in the decoded .ze_info text.
SPILL_SIZE_RE = re.compile(r'spill_size\s*[:=]\s*(\d+)')


def extract_spill_size_from_zebin(file):
    """Return the spill size recorded in the ``.ze_info`` section of a zebin.

    Args:
        file: Path of the zebin (ELF) file to inspect.

    Returns:
        The integer captured by the first ``SPILL_SIZE_RE`` match in the
        UTF-8 decoded section contents, or 0 when the section contains no
        ``spill_size`` entry.

    Raises:
        RuntimeError: If the ELF file has no ``.ze_info`` section.
    """
    with open(file, 'rb') as f:
        zeinfo = ELFFile(f).get_section_by_name(".ze_info")
        if zeinfo is None:
            # The two literals are concatenated, so the first one must end with
            # a space to keep the message readable.
            raise RuntimeError('Internal Triton ZEBIN codegen error: '
                               'Section .ze_info not found in zebin')
        # Kept inside the `with` block, as before: the section contents are
        # presumably read from the still-open file stream.
        text = zeinfo.data().decode('utf-8')

    # NOTE(review): only the first spill_size entry is considered; confirm that
    # the zebin is expected to describe a single kernel.
    match = SPILL_SIZE_RE.search(text)
    return int(match.group(1)) if match else 0
88+
89+
7190
class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
7291
arch_to_impl = {} # Architecture id to backend implementation class mapping
7392
binary_ext = "spv"
@@ -427,21 +446,20 @@ def make_zebin(cls, src, metadata, options):
427446

428447
ocloc_cmd = [
429448
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
430-
'-options', metadata["build_flags"] + shader_dump_opt
449+
'-options', metadata['build_flags'] + shader_dump_opt
431450
]
432451

433452
try:
434-
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
435-
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
436-
"""
437-
The exact message is something like:
438-
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
439-
is "spilled" enough for now?
440-
"""
441-
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
442-
# re-run with new build flags
443-
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
444-
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
453+
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
454+
if options.grf_mode == 'default':
455+
spill_size = extract_spill_size_from_zebin(fbin)
456+
# The threshold of 1000 for spill_size is chosen based on empirical observations
457+
# and aligned with triton/backends/intel/driver.c
458+
if spill_size > 1000:
459+
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
460+
# re-run with double GRF mode
461+
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
462+
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
445463
except subprocess.CalledProcessError as e:
446464
if e.returncode == 255:
447465
error = 'Internal Triton ZEBIN codegen error'

0 commit comments

Comments
 (0)