Commit 8683de8

[Fix] Make BaseModel._resolve_compile_cfg work. (#1481)
* [Fix] BaseModel._resolve_compile_cfg is not correct when FSDPConfig.torch_compile is the default value (False).
* [CI] Fix CI compile option usage (except test_resolve_compile).
* [Fix] Add warning for the deprecation of FSDPConfig.torch_compile.
1 parent bc35383 commit 8683de8

File tree

10 files changed (+18 -24 lines)


tests/engine/test_dense_train_engine.py

Lines changed: 0 additions & 2 deletions
@@ -40,7 +40,6 @@ def test_dense_engine_train(self, device, tp_size, sp_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             tp_size=tp_size,
             # hsdp_sharding_size=hsdp_sharding_size,
@@ -125,7 +124,6 @@ def test_save_and_load(self, device, tp_size, hsdp_sharding_size):
         moe_cfg = Qwen3Dense8BConfig()
         optim_cfg: AdamWConfig = AdamWConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             tp_size=tp_size,
             hsdp_sharding_size=hsdp_sharding_size,

tests/engine/test_moe_train_engine.py

Lines changed: 4 additions & 4 deletions
@@ -48,11 +48,11 @@ def test_moe_engine_train(self, device, ep_size, sp_size):
             ep_size=ep_size,
             balancing_loss_cfg=BalancingLossConfig(),
             z_loss_cfg=ZLossConfig(),
+            compile_cfg=False,
         )
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=False,
             cpu_offload=False,
             ep_size=ep_size,
             # hsdp_sharding_size=hsdp_sharding_size,
@@ -129,11 +129,11 @@ def test_moe_engine_train_freeze_routers(self, device, ep_size, sp_size):
             balancing_loss_cfg=BalancingLossConfig(),
             z_loss_cfg=ZLossConfig(),
             freeze_routers=True,
+            compile_cfg=False,
         )
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=False,
             cpu_offload=False,
             ep_size=ep_size,
             # hsdp_sharding_size=hsdp_sharding_size,
@@ -232,10 +232,10 @@ def test_save_and_load(self, device, ep_size, hsdp_sharding_size):
             ep_size=ep_size,
             balancing_loss_cfg=BalancingLossConfig(),
             z_loss_cfg=ZLossConfig(),
+            compile_cfg=False,
         )
         optim_cfg: AdamWConfig = AdamWConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=False,
             cpu_offload=False,
             ep_size=ep_size,
             hsdp_sharding_size=hsdp_sharding_size,
@@ -447,12 +447,12 @@ def create_engine_from_hf(load_from: Path, dispatcher: str | None, ep_size: int,
     moe_cfg : Qwen3MoEConfig = get_model_config_from_hf(load_from)
     moe_cfg.dispatcher = dispatcher
     moe_cfg.ep_size = ep_size
+    moe_cfg.compile_cfg = False
     if tiny:
         moe_cfg.num_hidden_layers = 2

     optim_cfg: AdamWConfig = AdamWConfig()
     fsdp_cfg: FSDPConfig = FSDPConfig(
-        torch_compile=False,
         cpu_offload=False,
         ep_size=ep_size,
     )

tests/engine/test_moe_train_engine_float8.py

Lines changed: 0 additions & 4 deletions
@@ -49,7 +49,6 @@ def test_tile_wise_fp8(self, device, ep_size, hsdp_sharding_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             ep_size=ep_size,
             # hsdp_sharding_size=8,
@@ -130,7 +129,6 @@ def test_tensor_wise_fp8(self, device, ep_size, hsdp_sharding_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             ep_size=ep_size,
             # hsdp_sharding_size=hsdp_sharding_size,
@@ -217,7 +215,6 @@ def test_save_and_load(self, device, ep_size, hsdp_sharding_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             ep_size=ep_size,
             # hsdp_sharding_size=hsdp_sharding_size,
@@ -323,7 +320,6 @@ def test_save_and_load1(self, device, ep_size, hsdp_sharding_size):
         )
         optim_cfg: AdamWConfig = AdamWConfig()
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             ep_size=ep_size,
             hsdp_sharding_size=hsdp_sharding_size,

tests/model/test_qwen3_tile_embedding.py

Lines changed: 0 additions & 2 deletions
@@ -45,7 +45,6 @@ def test_tie_embedding(self, device, tp_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig(lr_min=1e-3)
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             tp_size=tp_size
         )
@@ -114,7 +113,6 @@ def test_qwen3vl_tie_embedding(self, device, tp_size):
         optim_cfg: AdamWConfig = AdamWConfig()
         lr_cfg: LRConfig = LRConfig(lr_min=1e-3)
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             tp_size=tp_size
         )

tests/model/test_qwen3_vl.py

Lines changed: 2 additions & 8 deletions
@@ -206,16 +206,10 @@ def test_fsdp_qwen3_run(self, device, sp_size, compile, tol):
         patch_hf_rms_norm(hf_model)

         with torch.device("meta"):
-            model_cfg = Qwen3VLDense4BConfig()
-            if compile is False:
-                model_cfg.compile_cfg = False
+            model_cfg = Qwen3VLDense4BConfig(compile_cfg=compile)
             qwen3vl_model = model_cfg.build().to(torch.bfloat16)

-        fsdp_config = FSDPConfig(
-            cpu_offload=False,
-            torch_compile=compile
-        )
-
+        fsdp_config = FSDPConfig(cpu_offload=False)
         fsdp_mesh = init_world_mesh()
         qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
         qwen3vl_model.vision_tower.fsdp_config = fsdp_config

tests/ray/test_rl_train_with_sft.py

Lines changed: 0 additions & 1 deletion
@@ -80,7 +80,6 @@ def build_train_controller(self):
         model_cfg = Qwen3Dense8BConfig()
         optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
         fsdp_cfg: FSDPConfig = FSDPConfig(
-            torch_compile=True,
             cpu_offload=False,
             ep_size=1,
         )

tests/ray/test_rl_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -47,6 +47,7 @@ def tearDownClass(cls):

     def init_traine_worker_config(self, train_optimizer_steps, pack_max_length):
         model_cfg = get_model_config_from_hf(Path(MODEL_PATH))
+        model_cfg.compile_cfg = False
         optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
         loss_cfg = GRPOLossConfig(
             policy_loss_cfg=dict(
@@ -65,7 +66,7 @@ def init_traine_worker_config(self, train_optimizer_steps, pack_max_length):
             chunk_size=512,
         )
         lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
-        fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
+        fsdp_cfg = FSDPConfig(cpu_offload=False, ep_size=1)
         train_worker_cfg: WorkerConfig = WorkerConfig(
             model_cfg=model_cfg,
             load_from=MODEL_PATH,

tests/train/test_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -408,7 +408,7 @@ def prepare(self):

         self.optim_cfg = AdamWConfig(lr=0.1, weight_decay=0.1)
         self.lr_cfg = LRConfig(lr_type="cosine", lr_min=0.001, warmup_ratio=0.03)
-        self.fsdp_cfg = FSDPConfig(torch_compile=True)
+        self.fsdp_cfg = FSDPConfig()
         temp_dir = tempfile.TemporaryDirectory()
         if dist.get_rank() == 0:
             temp_dir = [temp_dir.name]
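
Taken together, the test updates above apply one pattern: stop toggling compilation through FSDPConfig.torch_compile and control it through the model config's compile_cfg instead. A minimal before/after sketch of that pattern; the FSDPConfig import path is taken from this commit's file layout, while the model-config import is hypothetical and only illustrative:

    # Sketch only: the model-config import is assumed, not shown in this diff.
    from xtuner.v1.config.fsdp import FSDPConfig
    # from xtuner.v1.model import Qwen3Dense8BConfig  # hypothetical import path

    # Before: compilation was toggled on the sharding config.
    # fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False)

    # After: compilation is controlled by the model config; FSDPConfig no longer mentions it.
    model_cfg = Qwen3Dense8BConfig()
    model_cfg.compile_cfg = False             # or Qwen3VLDense4BConfig(compile_cfg=...) as in test_qwen3_vl.py
    fsdp_cfg = FSDPConfig(cpu_offload=False)  # no torch_compile argument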

xtuner/v1/config/fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ class FSDPConfig(BaseModel):
     # TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts
     param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16
     reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16
-    torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = False
+    # TODO: deprecate `torch_compile` in favor of `compile_cfg` in XTunerBaseModelConfig
+    torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = True
     mesh_prefix: Annotated[str, Parameter(help="Prefix for device mesh configuration in distributed training")] = (
         "default"
     )
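
Note the default flip: torch_compile now defaults to True, so an FSDPConfig that never mentions compilation no longer forces the model's compile_cfg off in the resolver shown below, which was the bug described in the commit message. A minimal check, assuming the import path from this commit:

    from xtuner.v1.config.fsdp import FSDPConfig

    cfg = FSDPConfig(cpu_offload=False)
    assert cfg.torch_compile is True  # default value no longer disables compilation downstream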

xtuner/v1/train/trainer.py

Lines changed: 7 additions & 0 deletions
@@ -1845,5 +1845,12 @@ def _print_training_config(self):
         logger.info(f"Training config: {config_str}")

     def _resolve_deprecate_compile_cfg(self, model_cfg: XTunerBaseModelConfig, fsdp_cfg: FSDPConfig):
+        if self.rank == 0:
+            logger.warning(
+                "FSDPConfig.torch_compile is deprecated, and will be removed in version 1.1.0. "
+                "Please use XTunerBaseModelConfig.compile_cfg to control whether to use torch.compile for the model"
+            )
         if not fsdp_cfg.torch_compile:
+            if self.rank == 0:
+                logger.warning("FSDPConfig.torch_compile is set to False, setting model_cfg.compile_cfg to False.")
             model_cfg.compile_cfg = False
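
The resolution rule added here is small: an explicit torch_compile=False on the FSDP config still wins and forces compile_cfg off, so existing configs keep working, but the trainer now warns on rank 0 that the option is deprecated. A self-contained sketch of that rule using plain stand-in classes (not the xtuner types), assuming only what the hunk above shows:

    from dataclasses import dataclass

    @dataclass
    class _ModelCfg:    # stand-in for XTunerBaseModelConfig; compile_cfg may be a richer object in xtuner
        compile_cfg: object = True

    @dataclass
    class _FSDPCfg:     # stand-in for FSDPConfig, using the new default of True
        torch_compile: bool = True

    def resolve_compile_cfg(model_cfg: _ModelCfg, fsdp_cfg: _FSDPCfg) -> _ModelCfg:
        # Deprecated torch_compile=False still takes precedence and disables compilation on the model config.
        if not fsdp_cfg.torch_compile:
            model_cfg.compile_cfg = False
        return model_cfg

    assert resolve_compile_cfg(_ModelCfg(), _FSDPCfg(torch_compile=False)).compile_cfg is False
    assert resolve_compile_cfg(_ModelCfg(), _FSDPCfg()).compile_cfg is True  # default no longer overrides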
