
Commit c099a86

Merge branch 'main' into etiotto.remove_masks
2 parents: 9140e7f + 9e23713

10 files changed: +99 −87 lines

python/test/unit/language/test_core.py

Lines changed: 1 addition & 12 deletions
@@ -1352,8 +1352,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     n_programs = 5

@@ -1442,8 +1440,6 @@ def kernel(X):
                           for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel

@@ -1523,8 +1519,6 @@ def torch_to_triton_dtype(t):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):

@@ -1549,8 +1543,6 @@ def kernel(X, val, NUM: tl.constexpr):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):

@@ -1587,9 +1579,6 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")

-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
-
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK

@@ -5872,7 +5861,7 @@ def simple(data, out):

 def test_num_ctas_pre_sm90(device):
     if not is_cuda() and not is_hip():
-        pytest.skip("Only supported on CUDA and HIP")
+        pytest.xfail("Only supported on CUDA and HIP")

     @triton.jit
     def _kernel(src):
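
Note on the last hunk: pytest.skip and pytest.xfail both abort the test immediately when called imperatively; the change only affects reporting, so unsupported backends now show up as expected failures (XFAIL) rather than skips. A minimal sketch of the pattern, where backend_supported() is a hypothetical stand-in for is_cuda() or is_hip():

    import pytest

    def test_num_ctas_pre_sm90_pattern():
        if not backend_supported():  # hypothetical guard
            pytest.xfail("Only supported on CUDA and HIP")  # reported as XFAIL, not SKIPPED
        ...  # rest of the test runs only on supported backends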

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 0 additions & 3 deletions
@@ -1566,9 +1566,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
         pytest.xfail("Multi-CTA not supported")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")
-    if is_xpu():
-        if (kind, dtype_str) in [("add", "bfloat16")]:
-            pytest.skip("FIXME: issue #3914")

     @triton.jit(debug=True)
     def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):

python/triton/compiler/compiler.py

Lines changed: 12 additions & 7 deletions
@@ -138,12 +138,19 @@ def parse(full_name, ext, context):
         return module
     if ext == "llir" or ext == "ptx" or ext == "amdgcn":
         return Path(full_name).read_text()
-    if ext == "cubin" or ext == "hsaco":
+    if ext == "cubin" or ext == "hsaco" or ext == "zebin":
         return Path(full_name).read_bytes()
     if ext == "spv":
         return Path(full_name).read_bytes()


+def read_file(full_name, ext):
+    if ext in ["cubin", "hsaco", "spv", "zebin"]:
+        return Path(full_name).read_bytes()
+    else:
+        return Path(full_name).read_text()
+
+
 def filter_traceback(e: BaseException):
     """
     Removes code_generator.py and related files from tracebacks.

@@ -332,7 +339,7 @@ def compile(src, target=None, options=None, _env_vars=None):
             print(f"\nOverriding kernel with file {full_name}")
             next_module = parse(full_name, ext, context)
         # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
-        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")):
+        if (not store_only_binary) or (ext in ("cubin", "hsaco", "zebin", "json", "spv")):
             metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
         if fn_dump_manager is not None:
             fn_dump_manager.put(next_module, ir_filename)

@@ -422,11 +429,9 @@ def __init__(self, src, metadata_group, hash):
         self.name = self.metadata.name
         # stores the text of each level of IR that was generated during compilation
         asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
-        binary_ext = backend.binary_ext
-        self.asm = AsmDict({
-            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
-            for file in asm_files
-        })
+
+        self.asm = AsmDict({file.suffix[1:]: read_file(file, file.suffix[1:]) for file in asm_files})
+        binary_ext = metadata.get("binary_ext", backend.binary_ext)
         self.metadata_group = metadata_group
         self.kernel = self.asm[binary_ext]
         # binaries are lazily initialized
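
Note: all cached artifacts are now read through the single read_file helper, and the kernel binary is selected by the binary_ext recorded in the compile metadata, falling back to the backend default. A hedged sketch of that selection (the "spv" default mirrors the Intel backend; CUDA/HIP backends default to "cubin"/"hsaco"):

    # metadata is the dict loaded from the cache group's .json entry
    metadata = {"binary_ext": "zebin"}   # written by make_zebin (see below)
    backend_default = "spv"              # i.e. backend.binary_ext

    binary_ext = metadata.get("binary_ext", backend_default)
    assert binary_ext == "zebin"         # self.kernel = self.asm["zebin"], raw bytes

    metadata = {}                        # SPIR-V-only compile: key never written
    assert metadata.get("binary_ext", backend_default) == "spv"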

python/triton/tools/compile.py

Lines changed: 2 additions & 1 deletion
@@ -183,7 +183,8 @@ def constexpr(s):
         if hints.get((i, ), None) == 16:
             suffix += 'd'
     func_name = '_'.join([out_name, sig_hash, suffix])
-    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
+    binary_ext = getattr(ccinfo.metadata, "binary_ext", backend.binary_ext)
+    asm = ccinfo.asm[binary_ext]  # store binary data once

     hex_ = str(binascii.hexlify(asm))[2:-1]
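
Note the asymmetry with compiler.py above: there the metadata is a plain dict (hence .get), while ccinfo.metadata in the AOT tool is an object, hence getattr with the same backend fallback. A minimal illustration with hypothetical stand-ins:

    from types import SimpleNamespace

    backend_default = "spv"  # stand-in for backend.binary_ext
    assert getattr(SimpleNamespace(binary_ext="zebin"), "binary_ext", backend_default) == "zebin"
    assert getattr(SimpleNamespace(), "binary_ext", backend_default) == "spv"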

scripts/skiplist/lts/language.txt

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4665
 python/test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float64-float64]
 python/test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float64-float64]
+# Below bfloat16 tests require IGC 1188 or above
+python/test/unit/language/test_core.py::test_atomic_rmw[r".*bfloat16.*"]@regexp
+python/test/unit/language/test_core.py::test_tensor_atomic_rmw[r".*bfloat16.*"]@regexp
+python/test/unit/language/test_core.py::test_tensor_atomic_add_non_exclusive_offset[r".*bfloat16.*"]@regexp
+python/test/unit/language/test_core.py::test_tensor_atomic_add_shift_1[r".*bfloat16.*"]@regexp
+python/test/unit/language/test_core.py::test_tensor_atomic_add_access_patterns[r".*bfloat16.*"]@regexp
+python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[r".*bfloat16.*"]@regexp
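
Note: the @regexp suffix is a convention of this repo's skiplist tooling, not of pytest: the bracketed part is matched as a regular expression against the test's parametrization, so a single line covers every bfloat16 variant. A rough sketch of such matching, assuming that convention:

    import re

    def entry_matches(entry: str, test_id: str) -> bool:
        # 'name[r"PATTERN"]@regexp' -> regex match over parameters; otherwise literal.
        if entry.endswith("@regexp"):
            name, _, rest = entry.partition('[r"')
            pattern = rest[:-len('"]@regexp')]
            return test_id.startswith(name + "[") and re.search(pattern, test_id) is not None
        return entry == test_id

    print(entry_matches(
        'python/test/unit/language/test_core.py::test_atomic_rmw[r".*bfloat16.*"]@regexp',
        'python/test/unit/language/test_core.py::test_atomic_rmw[add-bfloat16-all_neg-acq_rel]'))  # True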

setup.py

Lines changed: 9 additions & 0 deletions
@@ -481,6 +481,15 @@ def build_extension(self, ext):
             cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
         cmake_args.extend(thirdparty_cmake_args)

+        result = subprocess.run(["bash", "./scripts/capture-hw-details.sh"], stdout=subprocess.PIPE,
+                                stderr=subprocess.PIPE, check=True, text=True, env=os.environ.copy())
+        agama_version = None
+        for line in result.stdout.splitlines():
+            if line.startswith("AGAMA_VERSION="):
+                agama_version = line.split("=", 1)[1].strip()
+                break
+        cmake_args.append(f"-DAGAMA_VERSION={agama_version}")
+
         # configuration
         cfg = get_build_type()
         build_args = ["--config", cfg]
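
Note: the build now shells out to scripts/capture-hw-details.sh and forwards the detected driver version to CMake. Assuming the script prints KEY=VALUE lines (the format the loop above expects), the round trip looks like this; the GPU_DEVICE line is a hypothetical example:

    stdout = "GPU_DEVICE=Intel(R) Data Center GPU Max 1550\nAGAMA_VERSION=1146\n"

    agama_version = None
    for line in stdout.splitlines():
        if line.startswith("AGAMA_VERSION="):
            agama_version = line.split("=", 1)[1].strip()
            break

    print(f"-DAGAMA_VERSION={agama_version}")  # -> -DAGAMA_VERSION=1146

The resulting -DAGAMA_VERSION value is what FindSPIRVToLLVMTranslator.cmake (below) compares against "1146" to decide whether to apply 3122.patch.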

third_party/intel/backend/compiler.py

Lines changed: 44 additions & 40 deletions
@@ -376,55 +376,57 @@ def make_spv(src, metadata, options, device_arch):

         if knobs.intel.disable_igc_opt:
             metadata["build_flags"] += " -cl-opt-disable"
+        return spirv
+
+    @staticmethod
+    def make_zebin(src, metadata, options, device_arch):
+        metadata["binary_ext"] = "zebin"

         shader_dump_opt = ""
         if knobs.intel.dump_shader_info:
             # The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process.
             # Have to use the `ocloc` to generate the binary in sub-process to work around the limitation.
-            assert options.generate_native_code, "Only support native code generation with shader dump"
             shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'"

         metadata["generate_native_code"] = options.generate_native_code

-        if options.generate_native_code:
-            with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
-                with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
-                    fsrc.write(spirv)
-                fbin = fsrc.name + '.o'
-
-                ocloc_cmd = [
-                    'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
-                    '-options', metadata["build_flags"] + shader_dump_opt
-                ]
-
-                try:
-                    output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
-                    if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
-                        """
-                        The exact message is something like:
-                            warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
-                        is "spilled" enough for now?
-                        """
-                        metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
-                        # re-run with new build flags
-                        ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
-                        subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
-                except subprocess.CalledProcessError as e:
-                    if e.returncode == 255:
-                        error = 'Internal Triton ZEBIN codegen error'
-                    elif e.returncode == 128 + signal.SIGSEGV:
-                        error = '`ocloc` raised SIGSEGV'
-                    else:
-                        error = f'`ocloc` failed with error code {e.returncode}'
-
-                    raise RuntimeError(f'{error}\n'
-                                       f'`ocloc` stderr:\n{e.output}\n'
-                                       f'Repro command: {ocloc_cmd}\n') from e
-
-                with open(fbin, 'rb') as f:
-                    zebin = f.read()
-                return zebin
-        return spirv
+        with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
+            with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
+                fsrc.write(src)
+            fbin = fsrc.name + '.o'
+
+            ocloc_cmd = [
+                'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
+                metadata["build_flags"] + shader_dump_opt
+            ]
+
+            try:
+                output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
+                if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
+                    """
+                    The exact message is something like:
+                        warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
+                    is "spilled" enough for now?
+                    """
+                    metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
+                    # re-run with new build flags
+                    ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
+                    subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
+            except subprocess.CalledProcessError as e:
+                if e.returncode == 255:
+                    error = 'Internal Triton ZEBIN codegen error'
+                elif e.returncode == 128 + signal.SIGSEGV:
+                    error = '`ocloc` raised SIGSEGV'
+                else:
+                    error = f'`ocloc` failed with error code {e.returncode}'
+
+                raise RuntimeError(f'{error}\n'
+                                   f'`ocloc` stderr:\n{e.output}\n'
+                                   f'Repro command: {ocloc_cmd}\n') from e
+
+            with open(fbin, 'rb') as f:
+                zebin = f.read()
+            return zebin

     def add_stages(self, stages, options, language):
         if language == Language.TRITON:

@@ -434,6 +436,8 @@ def add_stages(self, stages, options, language):
             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
         stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
+        if options.generate_native_code:
+            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
         if knobs.runtime.add_stages_inspection_hook is not None:
             knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
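
Note: the refactor splits the old all-in-one make_spv. It now always returns SPIR-V, and the ocloc invocation lives in a dedicated make_zebin stage registered only when options.generate_native_code is set. Each stage's output feeds the next; a simplified sketch of the chaining (the runner below is illustrative, not the actual driver loop):

    # make_spv / make_zebin stand in for the bound backend methods.
    stages = {
        "spv": lambda src, md: make_spv(src, md, options, device_arch),      # LLVM IR -> SPIR-V bytes
        "zebin": lambda src, md: make_zebin(src, md, options, device_arch),  # SPIR-V -> zebin via ocloc
    }

    def run_pipeline(src, metadata):
        for _name, stage in stages.items():  # dict preserves insertion order: spv, then zebin
            src = stage(src, metadata)
        return src  # zebin bytes; metadata["binary_ext"] == "zebin"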

third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake

Lines changed: 22 additions & 20 deletions
@@ -26,28 +26,30 @@ if (NOT SPIRVToLLVMTranslator_FOUND)

   FetchContent_MakeAvailable(spirv-llvm-translator)

-  # FIXME: Don't apply patch when Agama driver is updated.
-  execute_process(
-    COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-    WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-    ERROR_QUIET
-    RESULT_VARIABLE PATCH_RESULT
-  )
-  if(PATCH_RESULT EQUAL 0)
+  # FIXME: Don't apply patch when LTS driver is updated.
+  if(DEFINED AGAMA_VERSION AND AGAMA_VERSION STREQUAL "1146")
     execute_process(
-      COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
+      COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+      ERROR_QUIET
+      RESULT_VARIABLE PATCH_RESULT
     )
-  else()
-    execute_process( # Check if the patch is already applied
-      COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
-    )
-  endif()
-  if(NOT PATCH_RESULT EQUAL 0)
-    message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
+    if(PATCH_RESULT EQUAL 0)
+      execute_process(
+        COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+        WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+        RESULT_VARIABLE PATCH_RESULT
+      )
+    else()
+      execute_process( # Check if the patch is already applied
+        COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+        WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+        RESULT_VARIABLE PATCH_RESULT
+      )
+    endif()
+    if(NOT PATCH_RESULT EQUAL 0)
+      message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
+    endif()
   endif()

   # FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.

third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp

Lines changed: 2 additions & 1 deletion
@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {

 static SPIRV::TranslatorOpts getSPIRVOpts() {
   SPIRV::TranslatorOpts SPIRVOpts{SPIRV::VersionNumber::SPIRV_1_4};
-  static constexpr std::array<SPIRV::ExtensionID, 18> AllowedExtensions{
+  static constexpr std::array<SPIRV::ExtensionID, 19> AllowedExtensions{
       SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
       SPIRV::ExtensionID::SPV_INTEL_2d_block_io,
       SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,

@@ -124,6 +124,7 @@ static SPIRV::TranslatorOpts getSPIRVOpts() {
       SPIRV::ExtensionID::SPV_INTEL_tensor_float32_conversion,
       SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
       SPIRV::ExtensionID::SPV_INTEL_vector_compute,
+      SPIRV::ExtensionID::SPV_KHR_bfloat16,
       SPIRV::ExtensionID::SPV_KHR_bit_instructions,
       SPIRV::ExtensionID::SPV_KHR_non_semantic_info};

third_party/intel/tools/intel/compile.cpp

Lines changed: 0 additions & 3 deletions
@@ -138,9 +138,6 @@ int32_t {kernel_name}(sycl::queue &stream, {signature}) {{
   size_t global_range_y = {gridY};
   size_t global_range_z = {gridZ};
   size_t local_range_x = {num_warps} * {threads_per_warp};
-  if (driver_version.find("+") != std::string::npos) {{
-    local_range_x = 16;
-  }}
   size_t local_range_y = 1;
   size_t local_range_z = 1;
