Describe the bug
Hello,
I'm having problems with a simple Triton kernel taken from PyTorch's Inductor helpers. The Python code simply segfaults; inspecting it in GDB shows an error while handling the target triple inside the Intel backend. The code in question is taken from https://github.com/pytorch/pytorch/blob/release/2.7/torch/_inductor/runtime/triton_helpers.py and condensed into the following script:
```python
import triton
import triton.language as tl

@triton.jit
def promote_to_tensor(x):
    return x + tl.zeros((1,), tl.int1)

@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()

@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)

maximum[(1,)](0.3, 0.2)
```
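For context, the kernel implements a NaN-propagating maximum. A plain-Python host-side analogue (my own sketch, not the Triton semantics themselves) behaves like this:

```python
import math

# Host-side sketch of what the `maximum` kernel computes:
# an elementwise maximum that propagates NaN from `a`.
def maximum_ref(a: float, b: float) -> float:
    mask = a > b
    # NaN != NaN, so a NaN in `a` forces `a` to be selected.
    mask = mask or (a != a)
    return a if mask else b

assert maximum_ref(0.3, 0.2) == 0.3
assert math.isnan(maximum_ref(math.nan, 1.0))
```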
Triton was installed indirectly via the vLLM distribution, which pulled in version 3.3.1 of pytorch-triton-xpu. I tried to rule out a conflict from vLLM installing both triton and pytorch-triton-xpu side by side, but the issue persists even when only one of these packages is installed.
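As a quick way to confirm which Triton distributions coexist in an environment, one can enumerate installed packages via importlib (a generic diagnostic, not specific to vLLM):

```python
# Diagnostic: list installed distributions whose name mentions "triton",
# to check whether triton and pytorch-triton-xpu are installed together.
from importlib.metadata import distributions

triton_dists = sorted(
    {d.metadata["Name"] for d in distributions()
     if d.metadata["Name"] and "triton" in d.metadata["Name"].lower()}
)
print(triton_dists)
```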
The error occurs on the following combinations that were tested:
- Intel image 2025.0.1, Triton 3.3.0, 4x Intel GPU Max 1550
- Intel image 2025.0.1, Triton 3.3.0/1, 4x Intel GPU Max 1100
- Intel image 2025.0.2, Triton 3.3.0, 4x Intel GPU Max 1550
GDB shows the following Python stack, indicating that the failure occurs at the LLVM IR stage, in code handling the target triple:
(gdb) py-bt
Traceback (most recent call first):
<built-in method set_spv_target_triple of PyCapsule object at remote 0x7fffee5890e0>
File "/usr/local/lib/python3.11/dist-packages/triton/backends/intel/compiler.py", line 347, in make_llir
intel.set_spv_target_triple(llvm_mod)
File "/usr/local/lib/python3.11/dist-packages/triton/backends/intel/compiler.py", line 434, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
File "/usr/local/lib/python3.11/dist-packages/triton/compiler/compiler.py", line ?, in compile
(failed to get frame line number)
File "/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py", line 563, in run
kernel = self.compile(src, target=target, options=options.__dict__)
File "/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py", line 336, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/root/triton_repro.py", line 19, in <module>
maximum[(1,)](0.3, 0.2)
The error looks like a string assignment through an uninitialized or invalid pointer, but I cannot pin it down further. Either way, the GDB stack trace:
#0 std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::operator= (this=this@entry=0xe8, __str=...) at /opt/rh/gcc-toolset-13/root/usr/include/c++/13/bits/basic_string.h:858
#1 0x00007fffef61c8fb in llvm::Module::setTargetTriple (T=..., this=<optimized out>) at /root/.triton/llvm/llvm-1188b1ff-almalinux-x64/include/llvm/IR/Module.h:341
#2 operator() (__closure=<optimized out>, mod=<optimized out>) at ../../../third_party/intel/triton_xpu.cc:277
#3 pybind11::detail::argument_loader<llvm::Module*>::call_impl<void, init_triton_intel(pybind11::module&&)::<lambda(llvm::Module*)>&, 0, pybind11::detail::void_type> (f=..., this=0x7fffffffd210)
at /opt/_internal/cpython-3.11.12/lib/python3.11/site-packages/pybind11/include/pybind11/cast.h:1624
#4 pybind11::detail::argument_loader<llvm::Module*>::call<void, pybind11::detail::void_type, init_triton_intel(pybind11::module&&)::<lambda(llvm::Module*)>&> (f=..., this=0x7fffffffd210)
at /opt/_internal/cpython-3.11.12/lib/python3.11/site-packages/pybind11/include/pybind11/cast.h:1598
#5 operator() (__closure=0x0, call=...) at /opt/python/cp311-cp311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:297
#6 _FUN () at /opt/python/cp311-cp311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:267
#7 0x00007fffef6356d0 in pybind11::cpp_function::dispatcher (self=<optimized out>, args_in=(None,), kwargs_in=0x0)
at /opt/python/cp311-cp311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:987
#8 0x000000000055401b in cfunction_call (func=<built-in method set_spv_target_triple of PyCapsule object at remote 0x7fffee5890e0>, args=<optimized out>, kwargs=<optimized out>)
at ../Objects/methodobject.c:542
#9 0x000000000052ed63 in _PyObject_MakeTpCall (tstate=0xa53c78 <_PyRuntime+166328>, callable=<built-in method set_spv_target_triple of PyCapsule object at remote 0x7fffee5890e0>, args=<optimized out>,
nargs=<optimized out>, keywords=0x0) at ../Objects/call.c:214
#10 0x000000000053c61d in _PyEval_EvalFrameDefault (tstate=<optimized out>, frame=<optimized out>, throwflag=<optimized out>) at ../Python/ceval.c:4769
#11 0x0000000000583f68 in _PyEval_EvalFrame (throwflag=0, frame=0x7ffff7aaf120, tstate=0xa53c78 <_PyRuntime+166328>) at ../Include/internal/pycore_ceval.h:73
#12 _PyEval_Vector (kwnames=<optimized out>, argcount=<optimized out>, args=<optimized out>, locals=0x0, func=0x7fffee35bd80, tstate=0xa53c78 <_PyRuntime+166328>) at ../Python/ceval.c:6434
#13 _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=<optimized out>, func=<function at remote 0x7fffee35bd80>) at ../Objects/call.c:393
#14 _PyObject_VectorcallTstate (tstate=0xa53c78 <_PyRuntime+166328>, callable=<function at remote 0x7fffee35bd80>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>)
at ../Include/internal/pycore_call.h:92
#15 0x00000000005838e6 in method_vectorcall (method=<optimized out>, args=0x7fffedfbb698, nargsf=<optimized out>, kwnames=<optimized out>) at ../Objects/classobject.c:59
#16 0x000000000056e9c7 in _PyVectorcall_Call (kwargs=<optimized out>, tuple=<optimized out>, callable=<method at remote 0x7fffedff0540>, func=0x5835e0 <method_vectorcall>,
tstate=0xa53c78 <_PyRuntime+166328>) at ../Objects/call.c:257
#17 _PyObject_Call (kwargs=<optimized out>, args=<optimized out>, callable=<method at remote 0x7fffedff0540>, tstate=0xa53c78 <_PyRuntime+166328>) at ../Objects/call.c:328
#18 PyObject_Call (callable=<method at remote 0x7fffedff0540>, args=<optimized out>, kwargs=<optimized out>) at ../Objects/call.c:355
#19 0x0000000000540ef7 in do_call_core (use_tracing=<optimized out>, kwdict={'grid': (1,), 'warmup': False}, callargs=(<float at remote 0x7ffff7b32c10>, <float at remote 0x7ffff7b31c10>),
func=<method at remote 0x7fffedff0540>, tstate=<optimized out>) at ../Python/ceval.c:7349
#20 _PyEval_EvalFrameDefault (tstate=<optimized out>, frame=<optimized out>, throwflag=<optimized out>) at ../Python/ceval.c:5376
#21 0x000000000060eefd in _PyEval_EvalFrame (throwflag=0, frame=0x7ffff7aaf020, tstate=0xa53c78 <_PyRuntime+166328>) at ../Include/internal/pycore_ceval.h:73
#22 _PyEval_Vector (tstate=tstate@entry=0xa53c78 <_PyRuntime+166328>, func=func@entry=0x7ffff7bedf80,
locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated), args=args@entry=0x0, argcount=argcount@entry=0, kwnames=kwnames@entry=0x0) at ../Python/ceval.c:6434
#23 0x000000000060e6a8 in PyEval_EvalCode (co=<code at remote 0x7ffff7bd0870>, globals=<optimized out>,
locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated)) at ../Python/ceval.c:1148
#24 0x000000000062dd2c in run_eval_code_obj (tstate=0xa53c78 <_PyRuntime+166328>, co=0x7ffff7bd0870,
globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated),
locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated)) at ../Python/pythonrun.c:1741
#25 0x0000000000629fd6 in run_mod (mod=<optimized out>, filename=<optimized out>,
globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated),
locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated), flags=<optimized out>, arena=<optimized out>) at ../Python/pythonrun.c:1762
#26 0x000000000063e8a7 in pyrun_file (fp=fp@entry=0xa99370, filename=filename@entry='/root/triton_repro.py', start=start@entry=257,
globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated),
locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/root/triton_repro.py') at remote 0x7ffff7c0f410>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7ba0ae0>, '__file__': '/root/triton_repro.py', '__cached__': None, 'triton': <module at remote 0x7ffff79d3dd0>, 'tl': <module at remote 0x7fffee1664d0>, 'promote_to_tensor': <JITFunction(fn=<function at remote 0x7ffff7bb04a0>, module='__main__', version=None, signature=<Signature at remote 0x7fffedfbbbe0>, do_not_specialize=[], do_not_specialize_on_alignment=[], starting_line_number=4, _repr=None, _fn_name='promote_to_tensor', launch_metadata=None, params=[<KernelParam(num=0, _param=<Parameter at remote 0x7ffff79d9540>, do_not_specialize=False, do_not_specialize_on_alignment=False, name='x', annotation='', is_constexpr=False) at remote 0x7ffff79d8d10>], hash='b48e7000fd62858340eca23fc0adbba19bdf63f4aeb39a66d0a25631f05381674', src='def promote_to_tensor(x):\n r...(truncated), closeit=closeit@entry=1, flags=0x7fffffffdc68) at ../Python/pythonrun.c:1657
#27 0x000000000063df09 in _PyRun_SimpleFileObject (fp=fp@entry=0xa99370, filename=filename@entry='/root/triton_repro.py', closeit=closeit@entry=1, flags=flags@entry=0x7fffffffdc68)
at ../Python/pythonrun.c:440
#28 0x000000000063dc7f in _PyRun_AnyFileObject (fp=0xa99370, filename='/root/triton_repro.py', closeit=1, flags=0x7fffffffdc68) at ../Python/pythonrun.c:79
#29 0x0000000000638907 in pymain_run_file_obj (skip_source_first_line=0, filename='/root/triton_repro.py', program_name='/usr/bin/python') at ../Modules/main.c:360
#30 pymain_run_file (config=0xa39cc0 <_PyRuntime+59904>) at ../Modules/main.c:379
#31 pymain_run_python (exitcode=0x7fffffffdc64) at ../Modules/main.c:605
#32 Py_RunMain () at ../Modules/main.c:684
#33 0x00000000005ff95d in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at ../Modules/main.c:738
#34 0x00007ffff7c9d1ca in __libc_start_call_main (main=main@entry=0x5ff8b0 <main>, argc=argc@entry=2, argv=argv@entry=0x7fffffffdea8) at ../sysdeps/nptl/libc_start_call_main.h:58
#35 0x00007ffff7c9d28b in __libc_start_main_impl (main=0x5ff8b0 <main>, argc=2, argv=0x7fffffffdea8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffde98)
at ../csu/libc-start.c:360
#36 0x00000000005ff7e5 in _start ()
Failing assembly instruction:
Dump of assembler code for function _ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEaSEOS4_:
0x00007fffef616e90 <+0>: push %rbp
0x00007fffef616e91 <+1>: mov %rdi,%rbp
0x00007fffef616e94 <+4>: push %rbx
0x00007fffef616e95 <+5>: lea 0x10(%rbp),%rax
0x00007fffef616e99 <+9>: mov %rsi,%rbx
0x00007fffef616e9c <+12>: sub $0x8,%rsp
=> 0x00007fffef616ea0 <+16>: mov (%rdi),%rdi
0x00007fffef616ea3 <+19>: mov 0x8(%rsi),%rdx
Register states:
(gdb) info registers
rax 0xf8 248
rbx 0x7fffffffd270 140737488343664
rcx 0x3 3
rdx 0x16 22
rsi 0x7fffffffd270 140737488343664
rdi 0xe8 232
rbp 0xe8 0xe8
rsp 0x7fffffffd1f0 0x7fffffffd1f0
r8 0x7fff073bc000 140733314744320
r9 0x7ffff7b7b240 140737349399104
r10 0x0 0
r11 0x0 0
r12 0x0 0
r13 0x7fffffffd230 140737488343600
r14 0x7fffffffd280 140737488343680
r15 0x1 1
rip 0x7fffef616ea0 0x7fffef616ea0 <std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::operator=(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&)+16>
eflags 0x10206 [ PF IF RF ]
cs 0x33 51
ss 0x2b 43
ds 0x0 0
es 0x0 0
fs 0x0 0
gs 0x0 0
k0 0xc0 192
k1 0x4b 75
k2 0xff003fff 4278206463
k3 0xfffe000000000000 18446181123756130304
k4 0xffffffff 4294967295
k5 0x0 0
k6 0x0 0
k7 0x0 0
fs_base 0x7ffff7c6e740 140737350395712
gs_base 0x0 0
Locals:
(gdb) info locals
__equal_allocs = true
(gdb) info args
this = 0xe8
__str = @0x7fffffffd270: "spir64-unknown-unknown"
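The argument dump above is what makes me suspect an invalid object pointer rather than a bad string: `this = 0xe8` sits far below the first mapped page, which is the pattern produced by accessing a member of a null (or wild) `llvm::Module*` at some fixed offset. The member-offset interpretation is my assumption, not something read from the LLVM headers:

```python
# `this = 0xe8` (= 232) lies in the never-mapped first page on Linux,
# consistent with `null_module_ptr + member_offset` rather than a valid
# std::string object. (The offset reading is an assumption.)
this_ptr = 0xE8
PAGE_SIZE = 4096
assert 0 < this_ptr < PAGE_SIZE  # guaranteed to fault on dereference
print(hex(this_ptr), this_ptr)
```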
Dockerfile to reproduce the environment:
```dockerfile
FROM intel/deep-learning-essentials:2025.0.1-0-devel-ubuntu24.04
RUN rm /etc/apt/sources.list.d/intel-graphics.list
RUN apt-get update -y
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get install -y python3.11 python3.11-dev python3.11-dbg python3-pip
RUN apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
    gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
    && curl -LsSf https://astral.sh/uv/install.sh | sh
RUN ln -fs /usr/bin/python3.11 /usr/bin/python && ln -fs /usr/bin/python3.11 /usr/bin/python3
RUN apt-get install -y --no-install-recommends vim numactl gdb micro
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
COPY common.txt .
COPY xpu.txt .
RUN pip install -r xpu.txt
WORKDIR /root
ENTRYPOINT ["/bin/bash"]
```
Environment details
Configuration is from the Dockerfile above. Two base images were tested: intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 and intel/deep-learning-essentials:2025.0.1-0-devel-ubuntu24.04.
Below is the output of the get_env script used by the vLLM team, collected on the intel/deep-learning-essentials:2025.0.1-0-devel-ubuntu24.04 configuration with 4x Max 1550:
Collecting environment information...
=====================================
PyTorch version: 2.7.0+xpu
PyTorch CXX11 ABI: Yes
IPEX version: 2.7.10+xpu
IPEX commit: 0e47515e4
Build type: Release
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-17ubuntu1) 12.3.0
Clang version: N/A
IGC version: 2025.0.4 (2025.0.4.20241205)
CMake version: version 4.0.3
Libc version: glibc-2.39
Python version: 3.11.13 (main, Jun 4 2025, 08:57:30) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-140-generic-x86_64-with-glibc2.39
Is XPU available: True
DPCPP runtime: 2025.0
MKL version: 2025.0
GPU models and configuration onboard:
N/A
GPU models and configuration detected:
* [0] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [1] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [2] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [3] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [4] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [5] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [6] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
* [7] _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.3.30049+10', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
Driver version:
* intel_opencl: 24.39.31294.21-1032~24.04
* level_zero: 1.3.30049.10-950~24.04
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 112
On-line CPU(s) list: 0-111
Vendor ID: GenuineIntel
BIOS Vendor ID: Intel(R) Corporation
Model name: Intel(R) Xeon(R) Platinum 8480+
BIOS Model name: Intel(R) Xeon(R) Platinum 8480+ CPU @ 2.0GHz
BIOS CPU family: 179
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 56
Socket(s): 2
Stepping: 6
CPU(s) scaling MHz: 20%
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55
NUMA node1 CPU(s): 56-111
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip] dpcpp-cpp-rt==2025.0.4
[pip] impi-devel==2021.14.1
[pip] impi-rt==2021.14.1
[pip] intel-cmplr-lib-rt==2025.0.4
[pip] intel-cmplr-lib-ur==2025.0.4
[pip] intel-cmplr-lic-rt==2025.0.4
[pip] intel_extension_for_pytorch==2.7.10+xpu
[pip] intel-opencl-rt==2025.0.4
[pip] intel-openmp==2025.0.4
[pip] intel-pti==0.10.1
[pip] intel-sycl-rt==2025.0.4
[pip] mkl==2025.0.1
[pip] mkl-dpcpp==2025.0.1
[pip] numpy==2.3.0
[pip] oneccl==2021.14.1
[pip] oneccl-bind-pt==2.7.0+xpu
[pip] oneccl-devel==2021.14.1
[pip] onemkl-sycl-blas==2025.0.1
[pip] onemkl-sycl-datafitting==2025.0.1
[pip] onemkl-sycl-dft==2025.0.1
[pip] onemkl-sycl-lapack==2025.0.1
[pip] onemkl-sycl-rng==2025.0.1
[pip] onemkl-sycl-sparse==2025.0.1
[pip] onemkl-sycl-stats==2025.0.1
[pip] onemkl-sycl-vm==2025.0.1
[pip] pytorch-triton-xpu==3.3.0
[pip] torch==2.7.0+xpu
[pip] torchaudio==2.7.0+xpu
[pip] torchvision==0.22.0+xpu
[pip] transformers==4.52.4