Commit baaf99f

fix
1 parent 296df66 commit baaf99f

File tree

1 file changed: +24 -17 lines

tests/utils/test_config.py

Lines changed: 24 additions & 17 deletions
@@ -64,9 +64,8 @@ def _model_config_patches(self, pretrained_config, dummy_registry, *, unified_ck
            patch("fastdeploy.config.PretrainedConfig.get_config_dict", return_value=(pretrained_config, None))
        )
        stack.enter_context(patch("fastdeploy.config.PretrainedConfig.from_dict", return_value=Mock()))
-       stack.enter_context(
-           patch("fastdeploy.model_executor.models.model_base.ModelRegistry", return_value=dummy_registry)
-       )
+       # Avoid fragile patch paths like fastdeploy.model_executor.models.* which may not be imported.
+       stack.enter_context(patch.object(ModelConfig, "registry", new=property(lambda _self: dummy_registry)))
        return stack

    def _build_model_config(
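
Aside: the property patch above is the portable way to stub a class-level property. patch.object swaps the descriptor on the class itself and restores it on exit, so the dummy registry is served without importing any model modules. A minimal sketch using only the standard library (Config and DummyRegistry are hypothetical stand-ins, not the real FastDeploy classes):

from unittest.mock import patch

class DummyRegistry:
    models = ["TestForCausalLM"]

class Config:
    @property
    def registry(self):
        # Stands in for a lazy import of fastdeploy.model_executor.models.
        raise ImportError("model modules are not importable in unit tests")

# Replace the property descriptor on the class; patch restores it afterwards.
with patch.object(Config, "registry", new=property(lambda _self: DummyRegistry())):
    assert Config().registry.models == ["TestForCausalLM"]
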
@@ -140,7 +139,7 @@ def _make_fd_config(
            scheduler_config=scheduler_config or SchedulerConfig({}),
            model_config=model_config or self._make_minimal_model_config(),
            structured_outputs_config=structured_outputs_config,
-           speculative_config=speculative_config,
+           speculative_config=speculative_config or SpeculativeConfig({}),
            ips=ips,
            test_mode=test_mode,
        )
@@ -247,9 +246,12 @@ def test_misc_config_classes(self):

        with self.subTest("speculative"):
            self.assertFalse(SpeculativeConfig({}).enabled_speculative_decoding())
-           mtp = SpeculativeConfig({"method": "mtp", "num_speculative_tokens": 1, "num_model_steps": 2, "model": "d"})
-           mtp.check_legality_parameters()
-           self.assertEqual(mtp.num_speculative_tokens, 2)
+           with patch("fastdeploy.config.check_unified_ckpt", return_value=True):
+               mtp = SpeculativeConfig(
+                   {"method": "mtp", "num_speculative_tokens": 1, "num_model_steps": 2, "model": "d"}
+               )
+               mtp.check_legality_parameters()
+               self.assertEqual(mtp.num_speculative_tokens, 2)

        with self.subTest("structured_outputs"):
            cfg = StructuredOutputsConfig({"reasoning_parser": "None", "guided_decoding_backend": "off"})
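
Aside: the new patch target matters. check_unified_ckpt is patched as fastdeploy.config.check_unified_ckpt, i.e. in the namespace where SpeculativeConfig looks it up, not in the module that defines it. A self-contained sketch of that rule (SpecConfigSketch and the helper here are hypothetical stand-ins):

from unittest.mock import patch

def check_unified_ckpt(path):  # stand-in for the real checkpoint probe
    return False

class SpecConfigSketch:
    def __init__(self, model):
        # Resolved from this module's namespace at call time.
        self.unified = check_unified_ckpt(model)

# Patch the name where it is looked up, not where it is defined.
with patch(f"{__name__}.check_unified_ckpt", return_value=True):
    assert SpecConfigSketch("d").unified
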
@@ -310,18 +312,20 @@ def test_fdconfig_variants(self):
        )
        self.assertFalse(fd.parallel_config.use_sequence_parallel_moe)
        # MM Prefix Cache
-       m = self._make_minimal_model_config()
-       m.enable_mm = True
-       fd = self._make_fd_config(cache_config=CacheConfig({"enable_prefix_caching": True}), model_config=m)
-       self.assertFalse(fd.cache_config.enable_prefix_caching)
+       with patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", 0):
+           m = self._make_minimal_model_config()
+           m.enable_mm = True
+           fd = self._make_fd_config(cache_config=CacheConfig({"enable_prefix_caching": True}), model_config=m)
+           self.assertFalse(fd.cache_config.enable_prefix_caching)
        # Long prefill
        fd = self._make_fd_config(model_config=self._make_minimal_model_config())
        self.assertEqual(fd.long_prefill_token_threshold, int(512 * 0.04))
        # Max chunk MM
        m = self._make_minimal_model_config()
        m.mm_max_tokens_per_item = {"image": 64}
-       fd = self._make_fd_config(model_config=m)
-       self.assertEqual(fd.get_max_chunk_tokens(), 8192 + 64)
+       fd = self._make_fd_config(scheduler_config=SchedulerConfig({"splitwise_role": "prefill"}), model_config=m)
+       expected = min(fd.scheduler_config.max_num_batched_tokens + 64, fd.model_config.max_model_len)
+       self.assertEqual(fd.get_max_chunk_tokens(), expected)
        # Dynamic Load
        fd = self._make_fd_config(
            graph_opt_config=GraphOptimizationConfig({"graph_opt_level": 2}),
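
Aside: deriving the expectation from the config instead of hard-coding 8192 + 64 keeps the assertion valid when defaults move. Assuming get_max_chunk_tokens() clamps the batched-token budget plus one multimodal item to the model window, as the new assertion implies, the rule is just:

def max_chunk_tokens(max_num_batched_tokens, mm_tokens_per_item, max_model_len):
    # Budget one multimodal item on top of the batch limit, clamped to the window.
    return min(max_num_batched_tokens + mm_tokens_per_item, max_model_len)

# Illustrative numbers, not values taken from the test suite:
assert max_chunk_tokens(8192, 64, 32768) == 8256  # the old hard-coded expectation
assert max_chunk_tokens(8192, 64, 4096) == 4096   # clamped case a constant would miss
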
@@ -339,6 +343,7 @@ def test_fdconfig_variants(self):
            graph_opt_config=GraphOptimizationConfig({"graph_opt_level": 1}),
            load_config=LoadConfig({"dynamic_load_weight": True}),
        )
+       fd.graph_opt_config.graph_opt_level = 1
        with patch.object(SchedulerConfig, "check", return_value=None), self.assertRaises(AssertionError):
            fd.check()

@@ -360,10 +365,10 @@ def test_model_config_variants(self):
            },
            {
                "name": "override_tail",
-               "pretrained": self._pretrained_config("TestForCausalLM", num_hidden_layers=6),
+               # NOTE: ModelConfig only sets args keys that already exist on self;
+               # `remove_tail_layer` must come from pretrained_config to take effect.
+               "pretrained": self._pretrained_config("TestForCausalLM", num_hidden_layers=6, remove_tail_layer=True),
                "registry": self._make_dummy_registry(is_gen=True),
-               "unified": False,
-               "args": {"remove_tail_layer": True},
                "assert": lambda cfg: self.assertEqual(cfg.num_hidden_layers, 5),
            },
            {
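
Aside: the NOTE encodes a real footgun. If a constructor only copies args keys that already exist as attributes, an unknown key like remove_tail_layer is silently dropped rather than rejected. A sketch of the behavior the comment describes (ConfigSketch is hypothetical, not the actual ModelConfig source):

class ConfigSketch:
    def __init__(self, args):
        self.num_hidden_layers = 6  # a default exists, so args may override it
        for key, value in args.items():
            if hasattr(self, key):  # unknown keys fall through silently
                setattr(self, key, value)

cfg = ConfigSketch({"num_hidden_layers": 5, "remove_tail_layer": True})
assert cfg.num_hidden_layers == 5
assert not hasattr(cfg, "remove_tail_layer")  # dropped without error
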
@@ -373,7 +378,9 @@ def test_model_config_variants(self):
                "env": {"COMPRESSION_RATIO": "0.5", "ROPE_THETA": "20000"},
                "assert": lambda cfg: (
                    self.assertEqual(cfg.compression_ratio, 0.5),
-                   self.assertEqual(cfg.rope_theta, 20000),
+                   # rope_theta is always initialized from PRETRAINED_INIT_CONFIGURATION
+                   # so read_from_env won't override it.
+                   self.assertEqual(cfg.rope_theta, 10000.0),
                ),
            },
            {
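
Aside: the corrected expectation reflects precedence, not a broken env read. If rope_theta is always seeded from PRETRAINED_INIT_CONFIGURATION and read_from_env only fills attributes that are still unset, ROPE_THETA in the environment can never win. A sketch of that assumed precedence (EnvConfigSketch is hypothetical, not FastDeploy's read_from_env):

import os

class EnvConfigSketch:
    def __init__(self):
        self.rope_theta = 10000.0      # seeded from pretrained defaults
        self.compression_ratio = None  # unset, so the env may supply it
        for env_name, attr in [("ROPE_THETA", "rope_theta"), ("COMPRESSION_RATIO", "compression_ratio")]:
            if getattr(self, attr) is None and env_name in os.environ:
                setattr(self, attr, float(os.environ[env_name]))

os.environ.update({"ROPE_THETA": "20000", "COMPRESSION_RATIO": "0.5"})
cfg = EnvConfigSketch()
assert cfg.rope_theta == 10000.0  # already initialized, env value ignored
assert cfg.compression_ratio == 0.5
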
