PipeshardParallel + GPT2 example fails with compile error and segmentation fault #863
Please describe the bug
I'm trying to use PipeshardParallel for the GPT2 example in examples/gpt2 (20debbe) with Alpa v0.2.2 inside a Docker container, on an RHEL node with four NVIDIA A40 GPUs.
- Compilation for submesh (1, 4) fails with a check failure:
check failed: strategies->is_tuple || !strategies->leaf_vector.empty() %pad.38 = f16[8,512,2304]{2,1,0} pad(f16[8,512,768]{2,1,0} %reshape.1367, f16[] %constant.1168), padding=0_0x0_0x1536_0, metadata={op_name="parallelize(stage_0_1)/jit(main)/jit(merged)/jit(stage_0_1_compute2)/transpose(jvp(FlaxGPT2LMHeadModule))/transformer/h/11/attn/pad[padding_config=((0, 0, 0), (0, 0, 0), (1536, 0, 0))]" source_file="/opt/conda/envs/alpa/lib/python3.8/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py" source_line=211} does not have any valid strategies.
- Profiling somehow finishes anyway, but execution then fails with a jaxlib-level attribute error and a segmentation fault (a quick jaxlib attribute check is sketched below):
AttributeError: module 'jaxlib.xla_extension' has no attribute 'nccl_create_communicators_no_stream'
- See below for the full segfault stack trace.
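To narrow down the attribute error, here is a minimal diagnostic sketch (my own addition, not part of the example) that checks whether the installed jaxlib build exposes Alpa's custom NCCL binding; given the AttributeError above, I would expect it to print False in this environment:

# Diagnostic sketch (not part of the repro): check whether the installed
# jaxlib build exposes the custom NCCL binding that the traceback expects.
from jaxlib import xla_extension

name = "nccl_create_communicators_no_stream"
print(name, "present:", hasattr(xla_extension, name))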
Please describe the expected behavior
Training should compile and run to completion, without the auto-sharding check failure, the missing jaxlib attribute, or the segmentation fault.
System information and environment
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): RHEL 8.7, Docker
- Python version: 3.8.16
- CUDA version: 11.3
- NCCL version: 2.9.6
- cupy version: cupy-cuda113 10.6.0
- GPU model and memory: Four NVIDIA A40, 46068MiB
- Alpa version: 0.2.2
- TensorFlow version: 2.11.0
- JAX version: 0.3.22
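For reference, a minimal sketch of how the Python-side versions above can be collected from inside the container (assuming each package imports cleanly and exposes __version__):

# Version-report sketch (my addition): prints the package versions listed above.
import alpa, jax, jaxlib, tensorflow as tf, cupy
print("alpa:", alpa.__version__)
print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("tensorflow:", tf.__version__)
print("cupy:", cupy.__version__)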
To Reproduce
Steps to reproduce the behavior:
- Build the Docker image with docker/coreweave/run_alpa_infiniband.Dockerfile. All following commands are run inside the container.
- git clone --recursive https://github.com/alpa-projects/alpa.git
- cd alpa/examples/gpt2
- Edit run_clm_flax.py so that it uses PipeshardParallel instead of Zero2Parallel:
method = alpa.PipeshardParallel(
    devices=None,
    num_micro_batches=training_args.num_micro_batches,  # 16 in this case
    default_auto_sharding_option=None,
    pipeline_schedule="1f1b",
    layer_option=None,
    stage_option="auto",
    stage_input_shardings=None,
)
- pip install transformers datasets (installs transformers 4.25.1 and datasets 2.8.0)
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
- pip install tensorflow
- mkdir norwegian-gpt2 && python train_tokenizer.py && python create_config.py
- python3 run_clm_flax.py \
    --output_dir="./norwegian-gpt2" \
    --model_type="gpt2" \
    --config_name="./norwegian-gpt2" \
    --tokenizer_name="./norwegian-gpt2" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --do_train \
    --block_size="512" \
    --per_device_train_batch_size="32" \
    --num_micro_batches="16" \
    --dtype="float16" \
    --learning_rate="1e-3" --warmup_steps="1000" \
    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
    --overwrite_output_dir \
    --num_train_epochs="20" \
    --logging_steps="20" \
    --save_steps="2500" \
    --eval_steps="2500"
Full error output (tqdm disabled)
INFO:__main__:***** Running training *****
INFO:__main__: Num examples = 1966029
INFO:__main__: Num Epochs = 20
INFO:__main__: Batch size per device (w. accumulation) = 32
INFO:__main__: Global train batch size (w. parallel & distributed) = 128
INFO:__main__: Total optimization steps = 307180
Initial compilation. This might take some minutes...
-------------------- Automatic stage clustering --------------------
submesh_choices: ((1, 1), (1, 2), (1, 4))
- Profiling for submesh 2 (1, 4):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
(CompileWorker pid=73427) 2023-01-19 21:55:05.286251: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding.cc:1465] Check failed: strategies->is_tuple || !strategies->leaf_vector.empty() %pad.38 = f16[8,512,2304]{2,1,0} pad(f16[8,512,768]{2,1,0} %reshape.1367, f16[] %constant.1168), padding=0_0x0_0x1536_0, metadata={op_name="parallelize(stage_0_1)/jit(main)/jit(merged)/jit(stage_0_1_compute2)/transpose(jvp(FlaxGPT2LMHeadModule))/transformer/h/11/attn/pad[padding_config=((0, 0, 0), (0, 0, 0), (1536, 0, 0))]" source_file="/opt/conda/envs/alpa/lib/python3.8/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py" source_line=211} does not have any valid strategies.
(CompileWorker pid=73427) *** SIGABRT received at time=1674165305 on cpu 46 ***
(CompileWorker pid=73427) PC: @ 0x7f698dbd000b (unknown) raise
(CompileWorker pid=73427) @ 0x7f698deed420 537164224 (unknown)
(CompileWorker pid=73427) @ 0x7f42c9fb207d 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=73427) @ 0x7f42cb6ce3b4 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=73427) @ 0x7f42cdf7f371 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=73427) @ 0x7f42cdf7ffc5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=73427) @ 0x7f42ca52cf24 80 xla::HloPassInterface::Run()
(CompileWorker pid=73427) @ 0x7f42ca536391 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=73427) @ 0x7f42ca52215a 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=73427) @ 0x7f42ca392730 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=73427) @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=73427) @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=73427) [2023-01-19 21:55:05,324 E 73427 73427] logging.cc:361: *** SIGABRT received at time=1674165305 on cpu 46 ***
(CompileWorker pid=73427) [2023-01-19 21:55:05,324 E 73427 73427] logging.cc:361: PC: @ 0x7f698dbd000b (unknown) raise
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f698deed420 537164224 (unknown)
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42c9fb207d 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cb6ce3b4 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cdf7f371 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cdf7ffc5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca52cf24 80 xla::HloPassInterface::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca536391 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca52215a 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca392730 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=73427) [2023-01-19 21:55:05,326 E 73427 73427] logging.cc:361: @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=73427) Fatal Python error: Aborted
(CompileWorker pid=73427)
(CompileWorker pid=73427) Stack (most recent call first):
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 344 in run_auto_sharding_pass
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/pipeline_parallel/stage_profiling.py", line 161 in compile_stage_for_profiling
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/function_manager.py", line 674 in actor_method_executor
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/worker.py", line 763 in main_loop
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/workers/default_worker.py", line 231 in <module>
2023-01-19 21:55:09,285 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffbde8ec2d3c9b71076befcba108000000 Worker ID: db6d86ae01244973e089c907ce3105bc69cd346479e7764b48feb453 Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10300 Worker PID: 73427 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
WARNING:alpa.pipeline_parallel.stage_profiling:A Compile worker died unexpectedly: The actor died unexpectedly before finishing this task.
class_name: CompileWorker
actor_id: bde8ec2d3c9b71076befcba108000000
pid: 73427
namespace: alpa_default_space
ip: REDACTED_IP_ADDRESS
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
- Profile all stages
cost[0, 1, 0]=0.036, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=2.165GB, intermediate=0.000GB, init=0.348GB, as_config=((4, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.082, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=2.491GB, intermediate=0.000GB, init=0.348GB, as_config=((1, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 3]=0.033, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=1.991GB, intermediate=0.000GB, init=0.348GB, as_config=((4, 1), {})
Profiling for submesh 2 (1, 4) takes 44.25 seconds
Profiled costs are: [[[ inf inf inf inf]
[0.03560379 inf 0.08236101 0.03326474]]
[[ inf inf inf inf]
[ inf inf inf inf]]]
Profiled max_n_succ_stages are: [[[ -1 -1 -1 -1]
[4096 -1 4096 4096]]
[[ -1 -1 -1 -1]
[ -1 -1 -1 -1]]]
--------------------------------------------------
- Profiling for submesh 1 (1, 2):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 0]=0.024, max_n_succ_stage=27, Mem: avail=39.475GB, peak=1.626GB, intermediate=1.295GB, init=0.494GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 1]=0.029, max_n_succ_stage=23, Mem: avail=39.475GB, peak=1.708GB, intermediate=1.502GB, init=0.494GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 2]=0.023, max_n_succ_stage=27, Mem: avail=39.475GB, peak=1.520GB, intermediate=1.295GB, init=0.494GB, as_config=((2, 1), {})
cost[0, 1, 0]=0.056, max_n_succ_stage=10, Mem: avail=39.475GB, peak=3.690GB, intermediate=3.061GB, init=0.695GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.066, max_n_succ_stage=9, Mem: avail=39.475GB, peak=3.931GB, intermediate=3.460GB, init=0.696GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 0]=0.032, max_n_succ_stage=19, Mem: avail=39.475GB, peak=2.239GB, intermediate=1.766GB, init=0.489GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.054, max_n_succ_stage=10, Mem: avail=39.475GB, peak=3.527GB, intermediate=3.025GB, init=0.695GB, as_config=((2, 1), {})
cost[1, 1, 1]=0.037, max_n_succ_stage=17, Mem: avail=39.475GB, peak=2.306GB, intermediate=1.958GB, init=0.490GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 2]=0.031, max_n_succ_stage=20, Mem: avail=39.475GB, peak=2.081GB, intermediate=1.730GB, init=0.489GB, as_config=((2, 1), {})
Profiling for submesh 1 (1, 2) takes 51.14 seconds
Profiled costs are: [[[0.02386636 0.02937219 0.0229995 inf]
[0.05573034 0.06621422 0.05409217 inf]]
[[ inf inf inf inf]
[0.03158776 0.03738411 0.03124457 inf]]]
Profiled max_n_succ_stages are: [[[27 23 27 -1]
[10 9 10 -1]]
[[-1 -1 -1 -1]
[19 17 20 -1]]]
--------------------------------------------------
- Profiling for submesh 0 (1, 1):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 1]=0.040, max_n_succ_stage=13, Mem: avail=39.475GB, peak=2.900GB, intermediate=2.511GB, init=0.987GB, as_config=((1, 1), {})
cost[0, 0, 0]=0.040, max_n_succ_stage=13, Mem: avail=39.475GB, peak=2.900GB, intermediate=2.511GB, init=0.987GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 0]=0.056, max_n_succ_stage=9, Mem: avail=39.475GB, peak=4.118GB, intermediate=3.381GB, init=0.979GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 1]=0.056, max_n_succ_stage=9, Mem: avail=39.475GB, peak=4.118GB, intermediate=3.381GB, init=0.979GB, as_config=((1, 1), {})
cost[0, 1, 0]=0.095, max_n_succ_stage=4, Mem: avail=39.475GB, peak=6.835GB, intermediate=5.892GB, init=1.391GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.095, max_n_succ_stage=4, Mem: avail=39.475GB, peak=6.835GB, intermediate=5.892GB, init=1.391GB, as_config=((1, 1), {})
Profiling for submesh 0 (1, 1) takes 27.48 seconds
Profiled costs are: [[[0.03979082 0.03967738 inf inf]
[0.09511036 0.0951425 inf inf]]
[[ inf inf inf inf]
[0.05555977 0.05556859 inf inf]]]
Profiled max_n_succ_stages are: [[[13 13 -1 -1]
[ 4 4 -1 -1]]
[[-1 -1 -1 -1]
[ 9 9 -1 -1]]]
--------------------------------------------------
Compute cost saved to: compute-cost-2023-01-19-21-57-01.npy
----------------------------------------------------------------------
Result forward_stage_layer_ids: [[0], [1]]
Result mesh_shapes: [(1, 2), (1, 2)]
Result logical_mesh_shapes: [(2, 1), (2, 1)]
Result autosharding_option_dicts: [{}, {}]
2023-01-19 21:57:17,350 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=79095, ip=REDACTED_IP_ADDRESS, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f19ec368bb0>)
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 411, in create_and_set_cross_mesh_communicators
comms = g.get_nccl_collective_communicator(devices, "xla")
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 478, in get_nccl_collective_communicator
return self._get_nccl_collective_communicator(key, devices, lib)
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 455, in _get_nccl_collective_communicator
comms = xla_extension.nccl_create_communicators_no_stream(
AttributeError: module 'jaxlib.xla_extension' has no attribute 'nccl_create_communicators_no_stream'
2023-01-19 21:57:17,528 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=79094, ip=REDACTED_IP_ADDRESS, repr=<alpa.device_mesh.MeshHostWorker object at 0x7fe7b4724b80>)
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 411, in create_and_set_cross_mesh_communicators
comms = g.get_nccl_collective_communicator(devices, "xla")
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 478, in get_nccl_collective_communicator
return self._get_nccl_collective_communicator(key, devices, lib)
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 455, in _get_nccl_collective_communicator
comms = xla_extension.nccl_create_communicators_no_stream(
AttributeError: module 'jaxlib.xla_extension' has no attribute 'nccl_create_communicators_no_stream'
(MeshHostWorker pid=79094) [1674165445.753032] [REDACTED_HOST_NAME:79094:1] debug.c:1289 UCX WARN ucs_debug_disable_signal: signal 8 was not set in ucs
(MeshHostWorker pid=79094) [1674165445.753032] [REDACTED_HOST_NAME:79094:0] spinlock.c:29 UCX WARN ucs_recursive_spinlock_destroy() failed: busy
(MeshHostWorker pid=79095) [REDACTED_HOST_NAME:79095:0:79370] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x8)
(MeshHostWorker pid=79095) [REDACTED_HOST_NAME:79095:1:79368] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
(MeshHostWorker pid=79094) [REDACTED_HOST_NAME:79094:1:79400] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x8)
(MeshHostWorker pid=79094) [REDACTED_HOST_NAME:79094:0:79394] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
(MeshHostWorker pid=79095) [1674165445.755438] [REDACTED_HOST_NAME:79095:0] debug.c:1289 UCX WARN ucs_debug_disable_signal: signal 11 was not set in ucs
(MeshHostWorker pid=79095) [1674165445.755440] [REDACTED_HOST_NAME:79095:1] spinlock.c:29 UCX WARN ucs_recursive_spinlock_destroy() failed: busy
(MeshHostWorker pid=79095) ==== backtrace (tid: 79368) ====
(MeshHostWorker pid=79095) 0 0x0000000000014420 __funlockfile() ???:0
(MeshHostWorker pid=79095) 1 0x00000000021c6ca8 xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream() :0
(MeshHostWorker pid=79095) 2 0x000000000215d53e xla::gpu::(anonymous namespace)::ExecuteThunks() gpu_executable.cc:0
(MeshHostWorker pid=79095) 3 0x000000000215ed40 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime() :0
(MeshHostWorker pid=79095) 4 0x00000000021635f8 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl() :0
(MeshHostWorker pid=79095) 5 0x000000000216425f xla::gpu::GpuExecutable::ExecuteAsyncOnStream() :0
(MeshHostWorker pid=79095) 6 0x000000000454c306 xla::Executable::ExecuteAsyncOnStreamWrapper() :0
(MeshHostWorker pid=79095) 7 0x00000000013183d0 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79095) 8 0x0000000001318b40 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79095) 9 0x00000000012df9ea xla::PjRtStreamExecutorExecutable::EnqueueExecution() :0
(MeshHostWorker pid=79095) 10 0x00000000012e0e21 xla::PjRtStreamExecutorExecutable::ExecuteHelper() :0
(MeshHostWorker pid=79095) 11 0x00000000012e3249 std::_Function_handler<void (), xla::PjRtStreamExecutorExecutable::Execute(absl::lts_20220623::Span<std::vector<xla::PjRtBuffer*, std::allocator<xla::PjRtBuffer*> > const>, xla::ExecuteOptions const&, std::optional<std::vector<xla::PjRtFuture<tsl::Status>, std::allocator<xla::PjRtFuture<tsl::Status> > > >&)::{lambda()#2}>::_M_invoke() pjrt_stream_executor_client.cc:0
(MeshHostWorker pid=79095) 12 0x00000000012ef468 xla::WorkerThread::WorkLoop() :0
(MeshHostWorker pid=79095) 13 0x00000000056a7005 tsl::(anonymous namespace)::PThread::ThreadFn() env.cc:0
(MeshHostWorker pid=79095) 14 0x0000000000008609 start_thread() ???:0
(MeshHostWorker pid=79095) 15 0x000000000011f133 clone() ???:0
(MeshHostWorker pid=79095) =================================
(MeshHostWorker pid=79095) *** SIGSEGV received at time=1674165446 on cpu 110 ***
(MeshHostWorker pid=79095) PC: @ 0x7ef35fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79095) @ 0x7f1a218b6420 3728 (unknown)
(MeshHostWorker pid=79095) @ 0x7ef35fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79095) @ 0x7ef35fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79095) @ 0x7ef35fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79095) @ 0x7ef35fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79095) @ 0x7ef361f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79095) @ 0x7ef35ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) @ 0x7ef35ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) @ 0x7ef35ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79095) @ 0x7ef35ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79095) @ 0x7ef35ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79095) @ 0x7ef35ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79095) @ 0x7ef3630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79095) @ 0x7f1a218aa609 (unknown) start_thread
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: *** SIGSEGV received at time=1674165446 on cpu 110 ***
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: PC: @ 0x7ef35fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7f1a218b6420 3728 (unknown)
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef361f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef3630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7f1a218aa609 (unknown) start_thread
(MeshHostWorker pid=79095) Fatal Python error: Segmentation fault
(MeshHostWorker pid=79095)
(MeshHostWorker pid=79094) ==== backtrace (tid: 79394) ====
(MeshHostWorker pid=79094) 0 0x0000000000014420 __funlockfile() ???:0
(MeshHostWorker pid=79094) 1 0x00000000021c6ca8 xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream() :0
(MeshHostWorker pid=79094) 2 0x000000000215d53e xla::gpu::(anonymous namespace)::ExecuteThunks() gpu_executable.cc:0
(MeshHostWorker pid=79094) 3 0x000000000215ed40 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime() :0
(MeshHostWorker pid=79094) 4 0x00000000021635f8 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl() :0
(MeshHostWorker pid=79094) 5 0x000000000216425f xla::gpu::GpuExecutable::ExecuteAsyncOnStream() :0
(MeshHostWorker pid=79094) 6 0x000000000454c306 xla::Executable::ExecuteAsyncOnStreamWrapper() :0
(MeshHostWorker pid=79094) 7 0x00000000013183d0 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79094) 8 0x0000000001318b40 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79094) 9 0x00000000012df9ea xla::PjRtStreamExecutorExecutable::EnqueueExecution() :0
(MeshHostWorker pid=79094) 10 0x00000000012e0e21 xla::PjRtStreamExecutorExecutable::ExecuteHelper() :0
(MeshHostWorker pid=79094) 11 0x00000000012e3249 std::_Function_handler<void (), xla::PjRtStreamExecutorExecutable::Execute(absl::lts_20220623::Span<std::vector<xla::PjRtBuffer*, std::allocator<xla::PjRtBuffer*> > const>, xla::ExecuteOptions const&, std::optional<std::vector<xla::PjRtFuture<tsl::Status>, std::allocator<xla::PjRtFuture<tsl::Status> > > >&)::{lambda()#2}>::_M_invoke() pjrt_stream_executor_client.cc:0
(MeshHostWorker pid=79094) 12 0x00000000012ef468 xla::WorkerThread::WorkLoop() :0
(MeshHostWorker pid=79094) 13 0x00000000056a7005 tsl::(anonymous namespace)::PThread::ThreadFn() env.cc:0
(MeshHostWorker pid=79094) 14 0x0000000000008609 start_thread() ???:0
(MeshHostWorker pid=79094) 15 0x000000000011f133 clone() ???:0
(MeshHostWorker pid=79094) =================================
(MeshHostWorker pid=79094) *** SIGSEGV received at time=1674165446 on cpu 93 ***
(MeshHostWorker pid=79094) PC: @ 0x7fc15fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79094) @ 0x7fe81f685420 3728 (unknown)
(MeshHostWorker pid=79094) @ 0x7fc15fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79094) @ 0x7fc15fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79094) @ 0x7fc15fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79094) @ 0x7fc15fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79094) @ 0x7fc161f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79094) @ 0x7fc15ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) @ 0x7fc15ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) @ 0x7fc15ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79094) @ 0x7fc15ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79094) @ 0x7fc15ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79094) @ 0x7fc15ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79094) @ 0x7fc1630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79094) @ 0x7fe81f679609 (unknown) start_thread
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: *** SIGSEGV received at time=1674165446 on cpu 93 ***
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: PC: @ 0x7fc15fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fe81f685420 3728 (unknown)
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc161f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc1630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fe81f679609 (unknown) start_thread
(MeshHostWorker pid=79094) Fatal Python error: Segmentation fault
(MeshHostWorker pid=79094)
2023-01-19 21:57:34,541 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff84fa01cc7ea38596ba9e9dbe08000000 Worker ID: bf4eb1f8861d60f9d12849d496d767d75ae70600d62ca47b9d4101bd Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10343 Worker PID: 79095 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Traceback (most recent call last):
File "run_clm_flax.py", line 902, in <module>
main()
File "run_clm_flax.py", line 788, in main
executable.sync()
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/pipeline_parallel/pipeshard_executable.py", line 401, in sync
self.mesh_group.sync_workers()
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 2019, in sync_workers
ray.get([w.sync.remote() for w in all_workers])
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
return func(*args, **kwargs)
File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/worker.py", line 2291, in get
raise value
ray.exceptions.RayActorError: The actor died unexpectedly before finishing this task.
class_name: MeshHostWorker
actor_id: 84fa01cc7ea38596ba9e9dbe08000000
pid: 79095
namespace: alpa_default_space
ip: REDACTED_IP_ADDRESS
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
2023-01-19 21:57:34,838 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffa1ddb45e20cd16868664ad0208000000 Worker ID: 80c37f8712f00515006d806e5728697d4733b2dcde67767e582de8ad Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10342 Worker PID: 79094 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
As a side note, it would be great if there were a single Dockerfile that could build and run the Alpa HEAD commit.