|
24 | 24 | GpuSetting, |
25 | 25 | GpuSettingDict, |
26 | 26 | GpuSplitConfigDict, |
| 27 | + KvConfigFieldDict, |
27 | 28 | KvConfigStackDict, |
28 | 29 | LlmLoadModelConfig, |
29 | 30 | LlmLoadModelConfigDict, |
@@ -469,44 +470,87 @@ def test_parse_server_config_load_llm() -> None: |
469 | 470 | assert parse_server_config(server_config) == expected_client_config |
470 | 471 |
|
471 | 472 |
|
472 | | -def _other_gpu_split_strategies() -> Iterator[LlmSplitStrategy]: |
| 473 | +def _gpu_split_strategies() -> Iterator[LlmSplitStrategy]: |
473 | 474 | # Ensure all GPU split strategies are checked (these aren't simple structural transforms, |
474 | 475 | # so the default test case doesn't provide adequate test coverage ) |
475 | 476 | for split_strategy in get_args(LlmSplitStrategy): |
476 | | - if split_strategy == GPU_CONFIG["splitStrategy"]: |
477 | | - continue |
478 | 477 | yield split_strategy |
479 | 478 |
|
480 | 479 |
|
481 | | -def _find_config_field(stack_dict: KvConfigStackDict, key: str) -> Any: |
482 | | - for field in stack_dict["layers"][0]["config"]["fields"]: |
483 | | - if field["key"] == key: |
484 | | - return field["value"] |
| 480 | +def _find_config_field( |
| 481 | + stack_dict: KvConfigStackDict, key: str |
| 482 | +) -> tuple[int, KvConfigFieldDict]: |
| 483 | + for enumerated_field in enumerate(stack_dict["layers"][0]["config"]["fields"]): |
| 484 | + if enumerated_field[1]["key"] == key: |
| 485 | + return enumerated_field |
485 | 486 | raise KeyError(key) |
486 | 487 |
|
487 | 488 |
|
488 | | -@pytest.mark.parametrize("split_strategy", _other_gpu_split_strategies()) |
489 | | -def test_other_gpu_split_strategy_config(split_strategy: LlmSplitStrategy) -> None: |
| 489 | +def _del_config_field(stack_dict: KvConfigStackDict, key: str) -> None: |
| 490 | + field_index = _find_config_field(stack_dict, key)[0] |
| 491 | + field_list = cast(list[Any], stack_dict["layers"][0]["config"]["fields"]) |
| 492 | + del field_list[field_index] |
| 493 | + |
| 494 | + |
| 495 | +def _find_config_value(stack_dict: KvConfigStackDict, key: str) -> Any: |
| 496 | + return _find_config_field(stack_dict, key)[1]["value"] |
| 497 | + |
| 498 | + |
| 499 | +def _append_invalid_config_field(stack_dict: KvConfigStackDict, key: str) -> None: |
| 500 | + field_list = cast(list[Any], stack_dict["layers"][0]["config"]["fields"]) |
| 501 | + field_list.append({"key": key}) |
| 502 | + |
| 503 | + |
| 504 | +@pytest.mark.parametrize("split_strategy", _gpu_split_strategies()) |
| 505 | +def test_gpu_split_strategy_config(split_strategy: LlmSplitStrategy) -> None: |
| 506 | + # GPU config mapping is complex enough to need some additional testing |
| 507 | + input_camelCase = deepcopy(LOAD_CONFIG_LLM) |
| 508 | + input_snake_case = deepcopy(SC_LOAD_CONFIG_LLM) |
| 509 | + gpu_camelCase: GpuSettingDict = cast(Any, input_camelCase["gpu"]) |
| 510 | + gpu_snake_case: dict[str, Any] = cast(Any, input_snake_case["gpu"]) |
490 | 511 | expected_stack = deepcopy(EXPECTED_KV_STACK_LOAD_LLM) |
491 | | - if split_strategy == "favorMainGpu": |
492 | | - expected_split_config: GpuSplitConfigDict = _find_config_field( |
| 512 | + expected_server_config = expected_stack["layers"][0]["config"] |
| 513 | + gpu_camelCase["splitStrategy"] = gpu_snake_case["split_strategy"] = split_strategy |
| 514 | + if split_strategy == GPU_CONFIG["splitStrategy"]: |
| 515 | + assert split_strategy == "evenly", ( |
| 516 | + "Unexpected default LLM GPU offload split strategy (missing test case update?)" |
| 517 | + ) |
| 518 | + # There is no main GPU when the split strategy is even across GPUs |
| 519 | + del gpu_camelCase["mainGpu"] |
| 520 | + del gpu_snake_case["main_gpu"] |
| 521 | + elif split_strategy == "favorMainGpu": |
| 522 | + expected_split_config: GpuSplitConfigDict = _find_config_value( |
493 | 523 | expected_stack, "load.gpuSplitConfig" |
494 | 524 | ) |
495 | 525 | expected_split_config["strategy"] = "priorityOrder" |
496 | 526 | main_gpu = GPU_CONFIG["mainGpu"] |
497 | 527 | assert main_gpu is not None |
498 | 528 | expected_split_config["priority"] = [main_gpu] |
499 | 529 | else: |
500 | | - assert split_strategy is None, "Unknown LLM GPU offset split strategy" |
501 | | - input_camelCase = deepcopy(LOAD_CONFIG_LLM) |
502 | | - input_snake_case = deepcopy(SC_LOAD_CONFIG_LLM) |
503 | | - gpu_camelCase: GpuSettingDict = cast(Any, input_camelCase["gpu"]) |
504 | | - gpu_snake_case: dict[str, Any] = cast(Any, input_snake_case["gpu"]) |
505 | | - gpu_camelCase["splitStrategy"] = gpu_snake_case["split_strategy"] = split_strategy |
| 530 | + assert split_strategy is None, ( |
| 531 | + "Unknown LLM GPU offload split strategy (missing test case update?)" |
| 532 | + ) |
| 533 | + # Check given GPU config maps as expected in both directions |
| 534 | + kv_stack = load_config_to_kv_config_stack(input_camelCase, LlmLoadModelConfig) |
| 535 | + assert kv_stack.to_dict() == expected_stack |
| 536 | + kv_stack = load_config_to_kv_config_stack(input_snake_case, LlmLoadModelConfig) |
| 537 | + assert kv_stack.to_dict() == expected_stack |
| 538 | + assert parse_server_config(expected_server_config) == input_camelCase |
| 539 | + # Check a malformed ratio field is tolerated |
| 540 | + gpu_camelCase["ratio"] = gpu_snake_case["ratio"] = None |
| 541 | + _del_config_field(expected_stack, "llm.load.llama.acceleration.offloadRatio") |
| 542 | + _append_invalid_config_field( |
| 543 | + expected_stack, "llm.load.llama.acceleration.offloadRatio" |
| 544 | + ) |
| 545 | + assert parse_server_config(expected_server_config) == input_camelCase |
| 546 | + # Check mapping works if no explicit offload ratio is specified |
| 547 | + _del_config_field(expected_stack, "llm.load.llama.acceleration.offloadRatio") |
506 | 548 | kv_stack = load_config_to_kv_config_stack(input_camelCase, LlmLoadModelConfig) |
507 | 549 | assert kv_stack.to_dict() == expected_stack |
508 | 550 | kv_stack = load_config_to_kv_config_stack(input_snake_case, LlmLoadModelConfig) |
509 | 551 | assert kv_stack.to_dict() == expected_stack |
| 552 | + del gpu_camelCase["ratio"] |
| 553 | + assert parse_server_config(expected_server_config) == input_camelCase |
510 | 554 |
|
511 | 555 |
|
512 | 556 | @pytest.mark.parametrize("config_dict", (PREDICTION_CONFIG, SC_PREDICTION_CONFIG)) |
|
0 commit comments