Description
I have been running JetStream with the MaxText engine successfully on a v4-8. However, all my attempts to run it on v5e have failed.
I am starting the server on a v5e-16 with:
python -u -m MaxText.maxengine_server MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=gs://myserver/myunscanned-gemma2-9b/checkpoints/0/items max_prefill_predict_length=512 max_target_length=1024 model_name=gemma2-9b ici_fsdp_parallelism=4 ici_autoregressive_parallelism=1 ici_tensor_parallelism=4 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1
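For reference, the parallelism flags in this command multiply out to the device mesh size. The sketch below shows that arithmetic. It assumes that MaxText requires the product of the ICI parallelism values to equal the number of chips in the slice, and that ici_data_parallelism is left at 1. The helper name `ici_mesh_size` is hypothetical, not a MaxText function.

```python
from math import prod


def ici_mesh_size(data: int, fsdp: int, tensor: int, autoregressive: int) -> int:
    """Number of devices implied by the ICI parallelism flags (hypothetical helper)."""
    return prod([data, fsdp, tensor, autoregressive])


# Values from the command above; ici_data_parallelism is assumed to be 1.
size = ici_mesh_size(data=1, fsdp=4, tensor=4, autoregressive=1)
print(size)  # 16, the chip count of a v5e-16
```

Under that assumption, fsdp=4 × tensor=4 accounts for all 16 chips. The batch is then split only across the fsdp (and data) axes, not the tensor axis.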
When I connect with receiver.py pointed at localhost:9000, it hangs forever. With my own gRPC client, the server crashes with:
AssertionError: Batch dimension should be shardable among the devices in data and fsdp axis
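The assertion describes a divisibility requirement. The sketch below shows that kind of check. It assumes the batch dimension is sharded over the product of the data and fsdp mesh axes (1 × 4 = 4 with the flags above) and that a single request arrives as a batch of 1. Both points are assumptions, not confirmed MaxText behavior. `check_batch_shardable` is a hypothetical name, not MaxText's actual function.

```python
def check_batch_shardable(batch_size: int, data_axis: int, fsdp_axis: int) -> None:
    """Raise AssertionError if batch_size cannot be split evenly over data x fsdp."""
    shards = data_axis * fsdp_axis
    assert batch_size % shards == 0, (
        "Batch dimension should be shardable among the devices in data and fsdp axis"
    )


# Assumed values: one request (batch of 1), data axis of 1, fsdp axis of 4.
try:
    check_batch_shardable(batch_size=1, data_axis=1, fsdp_axis=4)
except AssertionError as err:
    print("fails:", err)

# A batch of 4 divides evenly over the same 4 shards, so no error is raised.
check_batch_shardable(batch_size=4, data_axis=1, fsdp_axis=4)
```

If these assumptions hold, a batch of 1 cannot be split four ways, while any multiple of 4 can.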