NeuronXCC Error in Compilation for Pytorch "all_to_all" operation: LLVM ERROR: allToAllOp does not have expected attributes #1142

@aleks-tu

Description

Hello everyone,

I wanted to test whether distributed PyTorch operations between NeuronCores are supported on the trn1.2xlarge instance type. I tried the all_to_all operation provided by the torch_xla package, using a slightly altered version of the all_to_all test from the PyTorch XLA GitHub repository. When executing the script, I get an LLVM error stating

LLVM ERROR: allToAllOp does not have expected attributes

Is this not supported? Thank you for your help.

Python script:

import sys
import os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) == 'NEURON':
    slots_per_device = 4
    size = slots_per_device * xm.xrt_world_size()
    ordinal = xm.get_ordinal()
    value = torch.tensor([ordinal] * size, dtype=torch.int32, device=device)
    result_tensor = xm.all_to_all(
        value,
        split_dimension=0,
        concat_dimension=0,
        split_count=xm.xrt_world_size())

    result = result_tensor.cpu().tolist()
    for i in range(0, xm.xrt_world_size()):
      expected = [i] * slots_per_device
      if expected != result[i * slots_per_device:(i + 1) * slots_per_device]:
        print(
            'Wrong result from core {}: {}'.format(i, result), file=sys.stderr)
        sys.exit(1)
  else:
    print(
        'Default device {} is not a NEURON device'.format(device), file=sys.stderr)

if __name__ == '__main__':
    os.environ["NEURONCORE_NUM_DEVICES"] = "2"
    xmp.spawn(_mp_fn, args=())
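
(Side note: the deprecation warnings in the output below point to torch_xla.runtime as the replacement for xm.xrt_world_size() and xm.get_ordinal(). Assuming the torch_xla 2.5 runtime API, the equivalent calls would presumably be:)

import torch_xla.runtime as xr

world_size = xr.world_size()     # replaces xm.xrt_world_size()
ordinal = xr.global_ordinal()    # replaces xm.get_ordinal()

Switching to these calls should only silence the warnings; it does not change the all_to_all behaviour or the compiler error.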

Steps to reproduce:

  1. Start a trn1.2xlarge instance using image-id "ami-080b4a9b6e048125e" (Deep Learning AMI Neuron (Amazon Linux 2023) 20250115)
  2. source /opt/aws_neuronx_venv_pytorch_2_5/bin/activate (PyTorch 2.5 Torch NeuronX, NxD Core Environment)
  3. python3 test_all_to_all.py

Python Output:

(aws_neuronx_venv_pytorch_2_5) [ec2-user@ ~]$ python3 test_all_to_all.py 
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING:root:Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
2025-04-13 16:29:08.000978:  3253  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.hlo_module.pb --output /tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.neff --target=trn1 --verbose=35
.
2025-04-13 16:29:26.000523:  3253  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.hlo_module.pb', '--output', '/tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.neff', '--target=trn1', '--verbose=35']: Process Process-1:
Traceback (most recent call last):
  File "neuronxcc/driver/CommandDriver.py", line 345, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand
  File "neuronxcc/driver/commands/CompileCommand.py", line 1353, in neuronxcc.driver.commands.CompileCommand.CompileCommand.run
  File "neuronxcc/driver/commands/CompileCommand.py", line 1304, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/commands/CompileCommand.py", line 1324, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/commands/CompileCommand.py", line 1327, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/Job.py", line 344, in neuronxcc.driver.Job.SingleInputJob.run
  File "neuronxcc/driver/Job.py", line 370, in neuronxcc.driver.Job.SingleInputJob.runOnState
  File "neuronxcc/driver/Pipeline.py", line 30, in neuronxcc.driver.Pipeline.Pipeline.runSingleInput
  File "neuronxcc/driver/Job.py", line 344, in neuronxcc.driver.Job.SingleInputJob.run
  File "neuronxcc/driver/Job.py", line 370, in neuronxcc.driver.Job.SingleInputJob.runOnState
  File "neuronxcc/driver/jobs/Frontend.py", line 454, in neuronxcc.driver.jobs.Frontend.Frontend.runSingleInput
  File "neuronxcc/driver/jobs/Frontend.py", line 218, in neuronxcc.driver.jobs.Frontend.Frontend.runXLAFrontend
  File "neuronxcc/driver/jobs/Frontend.py", line 190, in neuronxcc.driver.jobs.Frontend.Frontend.runHlo2Tensorizer
neuronxcc.driver.Exceptions.CompilerInvalidInputException: ERROR: Failed command  /opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/neuronxcc/starfish/bin/hlo2penguin --input /tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.hlo_module.pb --out-dir ./ --output penguin.py --layers-per-module=1 --emit-tensor-level-dropout-ops --emit-tensor-level-rng-ops
------------ 
Reported stdout: 
DEBUG: needsModular? No. macCnt 0
INFO: Switching to single-module compile. PrePartitionPipe skipped.
INFO: Found memory bound graph
Replaced 0 dropout sequences with OffloadedDropout
INFO: HloMacCount has found 0
INFO: Traffic has found 102
INFO: AIF 0
HLO Ops used in computation: add all-to-all broadcast concatenate convert get-tuple-element parameter slice tuple 
Invoking RemoveOptimizationBarriers pass
LLVM ERROR: allToAllOp does not have expected attributes

------------ 
Reported stderr: 
None
------------                       
Import of the HLO graph into the Neuron Compiler has failed.                       
This may be caused by unsupported operators or an internal compiler error.                       
More details can be found in the error message(s) above.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib64/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib64/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "neuronxcc/driver/CommandDriver.py", line 352, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand_in_process
  File "neuronxcc/driver/CommandDriver.py", line 347, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand
  File "neuronxcc/driver/CommandDriver.py", line 111, in neuronxcc.driver.CommandDriver.handleError
  File "neuronxcc/driver/GlobalState.py", line 102, in neuronxcc.driver.GlobalState.FinalizeGlobalState
  File "neuronxcc/driver/GlobalState.py", line 82, in neuronxcc.driver.GlobalState._GlobalStateImpl.shutdown
  File "/usr/lib64/python3.9/shutil.py", line 724, in rmtree
    onerror(os.lstat, path, sys.exc_info())
  File "/usr/lib64/python3.9/shutil.py", line 722, in rmtree
    orig_st = os.lstat(path)
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/neuronxcc-gjfc05yo'

2025-04-13 16:29:26.000523:  3253  ERROR ||NEURON_CC_WRAPPER||: Compilation failed for /tmp/ec2-user/neuroncc_compile_workdir/b11815d6-2409-4884-a408-cc6d3a174298/model.MODULE_15168660040147367519+e30acd3a.hlo_module.pb after 0 retries.
2025-04-13 16:29:28.000994:  3256  ERROR ||NEURON_CC_WRAPPER||: Got a cached failed neff at /var/tmp/neuron-compile-cache/neuronxcc-2.16.372.0+4a9b2326/MODULE_15168660040147367519+e30acd3a/model.neff. Will skip compilation, please set --retry_failed_compilation for recompilation: 
 None.
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib64/python3.9/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib64/python3.9/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib64/python3.9/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 609, in result_iterator
    yield fs.pop().result()
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
  File "/usr/lib64/python3.9/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
    return fn()
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 185, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/home/ec2-user/test_all_to_all.py", line 20, in _mp_fn
    result = result_tensor.cpu().tolist()
RuntimeError: Bad StatusOr access: INTERNAL: RunNeuronCCImpl: error condition error != 0: <class 'subprocess.CalledProcessError'>: Command '' died with <Signals.SIGHUP: 1>.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ec2-user/test_all_to_all.py", line 33, in <module>
    xmp.spawn(_mp_fn, args=())
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 37, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 209, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/opt/aws_neuronx_venv_pytorch_2_5/lib64/python3.9/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib64/python3.9/concurrent/futures/process.py", line 562, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 609, in result_iterator
    yield fs.pop().result()
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/usr/lib64/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: INTERNAL: RunNeuronCCImpl: error condition error != 0: <class 'subprocess.CalledProcessError'>: Command '' died with <Signals.SIGHUP: 1>.
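
(Note: as the log above mentions, the failed NEFF is cached under /var/tmp/neuron-compile-cache, so re-running the script skips compilation and reports the cached failure. To force recompilation after changing anything, I pass the flag suggested in the log; setting it via the NEURON_CC_FLAGS environment variable before the first compilation is triggered is my assumption of the usual mechanism, e.g. at the top of the script:)

import os

# Ask the Neuron compiler wrapper to retry instead of reusing the cached failed NEFF.
# Flag name taken from the error message above; passing it through NEURON_CC_FLAGS is an assumption.
os.environ["NEURON_CC_FLAGS"] = "--retry_failed_compilation"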
