Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def parse(full_name, ext, context):
return module
if ext == "llir" or ext == "ptx" or ext == "amdgcn":
return Path(full_name).read_text()
if ext == "cubin" or ext == "hsaco":
if ext == "cubin" or ext == "hsaco" or ext == "zebin":
return Path(full_name).read_bytes()
if ext == "spv":
return Path(full_name).read_bytes()
Expand Down Expand Up @@ -332,7 +332,7 @@ def compile(src, target=None, options=None, _env_vars=None):
print(f"\nOverriding kernel with file {full_name}")
next_module = parse(full_name, ext, context)
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")):
if (not store_only_binary) or (ext in ("cubin", "hsaco", "zebin", "json", "spv")):
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
if fn_dump_manager is not None:
fn_dump_manager.put(next_module, ir_filename)
Expand Down Expand Up @@ -422,11 +422,15 @@ def __init__(self, src, metadata_group, hash):
self.name = self.metadata.name
# stores the text of each level of IR that was generated during compilation
asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
binary_ext = backend.binary_ext
self.asm = AsmDict({
file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
for file in asm_files
})

def read_file(path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the spv and zebin files are in binary format. This change lets the intermediate files be dumped in either text or binary format.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's worth rewriting this without exceptions? They are usually noticeably slower.

try:
return path.read_text()
except UnicodeDecodeError:
return path.read_bytes()

self.asm = AsmDict({file.suffix[1:]: read_file(file) for file in asm_files})
binary_ext = metadata.get("binary_ext", backend.binary_ext)
self.metadata_group = metadata_group
self.kernel = self.asm[binary_ext]
# binaries are lazily initialized
Expand Down
3 changes: 2 additions & 1 deletion python/triton/tools/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def constexpr(s):
if hints.get((i, ), None) == 16:
suffix += 'd'
func_name = '_'.join([out_name, sig_hash, suffix])
asm = ccinfo.asm[backend.binary_ext] # store binary data once
binary_ext = getattr(ccinfo.metadata, "binary_ext", backend.binary_ext)
asm = ccinfo.asm[binary_ext] # store binary data once

hex_ = str(binascii.hexlify(asm))[2:-1]

Expand Down
85 changes: 45 additions & 40 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,55 +376,58 @@ def make_spv(src, metadata, options, device_arch):

if knobs.intel.disable_igc_opt:
metadata["build_flags"] += " -cl-opt-disable"
return spirv

@staticmethod
def make_zebin(src, metadata, options, device_arch):
metadata["binary_ext"] = "zebin"

shader_dump_opt = ""
if knobs.intel.dump_shader_info:
# The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process.
# Have to use the `ocloc` to generate the binary in sub-process to work around the limitation.
assert options.generate_native_code, "Only support native code generation with shader dump"
shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'"

metadata["generate_native_code"] = options.generate_native_code

if options.generate_native_code:
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
fsrc.write(spirv)
fbin = fsrc.name + '.o'

ocloc_cmd = [
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
'-options', metadata["build_flags"] + shader_dump_opt
]

try:
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
"""
The exact message is something like:
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
is "spilled" enough for now?
"""
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
# re-run with new build flags
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
except subprocess.CalledProcessError as e:
if e.returncode == 255:
error = 'Internal Triton ZEBIN codegen error'
elif e.returncode == 128 + signal.SIGSEGV:
error = '`ocloc` raised SIGSEGV'
else:
error = f'`ocloc` failed with error code {e.returncode}'

raise RuntimeError(f'{error}\n'
f'`ocloc` stderr:\n{e.output}\n'
f'Repro command: {ocloc_cmd}\n') from e

with open(fbin, 'rb') as f:
zebin = f.read()
return zebin
return spirv
with tempfile.TemporaryDirectory() as temp_dir:
with track("generate_native_code"), tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir,
delete=False) as fsrc:
fsrc.write(src)
fbin = fsrc.name + '.o'

ocloc_cmd = [
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
metadata["build_flags"] + shader_dump_opt
]

try:
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
"""
The exact message is something like:
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
is "spilled" enough for now?
"""
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
# re-run with new build flags
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
except subprocess.CalledProcessError as e:
if e.returncode == 255:
error = 'Internal Triton ZEBIN codegen error'
elif e.returncode == 128 + signal.SIGSEGV:
error = '`ocloc` raised SIGSEGV'
else:
error = f'`ocloc` failed with error code {e.returncode}'

raise RuntimeError(f'{error}\n'
f'`ocloc` stderr:\n{e.output}\n'
f'Repro command: {ocloc_cmd}\n') from e

with open(fbin, 'rb') as f:
zebin = f.read()
return zebin

def add_stages(self, stages, options, language):
if language == Language.TRITON:
Expand All @@ -434,6 +437,8 @@ def add_stages(self, stages, options, language):
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
if options.generate_native_code:
stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
if knobs.runtime.add_stages_inspection_hook is not None:
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)

Expand Down
3 changes: 0 additions & 3 deletions third_party/intel/tools/intel/compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,6 @@ int32_t {kernel_name}(sycl::queue &stream, {signature}) {{
size_t global_range_y = {gridY};
size_t global_range_z = {gridZ};
size_t local_range_x = {num_warps} * {threads_per_warp};
if (driver_version.find("+") != std::string::npos) {{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code doesn't make sense. Remove it.

local_range_x = 16;
}}
size_t local_range_y = 1;
size_t local_range_z = 1;

Expand Down
Loading