@@ -111,6 +111,8 @@ def test_pretrain_with_decentralized_pg(self, tmp_path):
             make_vocab_size_divisible_by=128,
             vocab_size=None,
             num_layers=1,
+            # Disable shared embeddings - not supported with decentralized PG
+            share_embeddings_and_output_weights=False,
         )
 
         # Config Container with use_decentralized_pg=True
@@ -366,6 +368,8 @@ def test_pretrain_with_decentralized_pg_and_pp(self, tmp_path):
             make_vocab_size_divisible_by=128,
             vocab_size=None,
             num_layers=2,  # Need at least 2 layers for PP=2
+            # Disable shared embeddings - not supported with decentralized PG
+            share_embeddings_and_output_weights=False,
         )
 
         # Config Container with use_decentralized_pg=True
@@ -496,6 +500,8 @@ def test_pretrain_with_decentralized_pg_and_cp(self, tmp_path):
             make_vocab_size_divisible_by=128,
             vocab_size=None,
             num_layers=1,
+            # Disable shared embeddings - not supported with decentralized PG
+            share_embeddings_and_output_weights=False,
         )
 
         # Config Container with use_decentralized_pg=True
@@ -626,6 +632,8 @@ def test_pretrain_with_decentralized_pg_combined_parallelism(self, tmp_path):
             make_vocab_size_divisible_by=128,
             vocab_size=None,
             num_layers=2,  # Need at least 2 layers for PP=2
+            # Disable shared embeddings - not supported with decentralized PG
+            share_embeddings_and_output_weights=False,
         )
 
         # Config Container with use_decentralized_pg=True
@@ -756,6 +764,8 @@ def test_pretrain_with_decentralized_pg_and_tp(self, tmp_path):
             make_vocab_size_divisible_by=128,
             vocab_size=None,
             num_layers=1,
+            # Disable shared embeddings - not supported with decentralized PG
+            share_embeddings_and_output_weights=False,
         )
 
         # Config Container with use_decentralized_pg=True
0 commit comments