|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | | -from ..models import ModelConfig |
| 4 | +from abc import ABC, abstractmethod |
| 5 | +from enum import Enum |
| 6 | +from typing import Literal, TypeVar, overload |
| 7 | + |
| 8 | +from ..models import ModelConfig, ModelProvider |
5 | 9 | from ..sampler_params import SamplerType |
6 | 10 | from .type_helpers import get_sampler_params |
7 | | -from .visualization import display_model_configs_table, display_sampler_table |
| 11 | +from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table |
8 | 12 |
|
9 | 13 |
|
10 | | -class DataDesignerInfo: |
11 | | - def __init__(self, model_configs: list[ModelConfig] | None = None): |
12 | | - self._sampler_params = get_sampler_params() |
13 | | - self._model_configs = model_configs or [] |
| 14 | +class InfoType(str, Enum): |
| 15 | + SAMPLERS = "SAMPLERS" |
| 16 | + MODEL_CONFIGS = "MODEL_CONFIGS" |
| 17 | + MODEL_PROVIDERS = "MODEL_PROVIDERS" |
| 18 | + |
| 19 | + |
| 20 | +ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS] |
| 21 | +DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS] |
| 22 | +InfoTypeT = TypeVar("InfoTypeT", bound=InfoType) |
| 23 | + |
14 | 24 |
|
15 | | - @property |
16 | | - def sampler_table(self) -> None: |
17 | | - display_sampler_table(self._sampler_params) |
| 25 | +class InfoDisplay(ABC): |
| 26 | + """Base class for info display classes that provide type-safe display methods.""" |
18 | 27 |
|
19 | | - @property |
20 | | - def sampler_types(self) -> list[str]: |
21 | | - return [s.value for s in SamplerType] |
| 28 | + @abstractmethod |
| 29 | + def display(self, info_type: InfoTypeT, **kwargs) -> None: |
| 30 | + """Display information based on the provided info type. |
22 | 31 |
|
23 | | - def set_model_configs(self, model_configs: list[ModelConfig]) -> None: |
| 32 | + Args: |
| 33 | + info_type: Type of information to display. |
| 34 | + """ |
| 35 | + ... |
| 36 | + |
| 37 | + |
| 38 | +class ConfigBuilderInfo(InfoDisplay): |
| 39 | + def __init__(self, model_configs: list[ModelConfig]): |
| 40 | + self._sampler_params = get_sampler_params() |
24 | 41 | self._model_configs = model_configs |
25 | 42 |
|
26 | | - def display_sampler(self, sampler_type: SamplerType) -> None: |
27 | | - title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" |
28 | | - display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) |
| 43 | + @overload |
| 44 | + def display(self, info_type: Literal[InfoType.SAMPLERS], **kwargs) -> None: ... |
| 45 | + |
| 46 | + @overload |
| 47 | + def display(self, info_type: Literal[InfoType.MODEL_CONFIGS], **kwargs) -> None: ... |
| 48 | + |
| 49 | + def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None: |
| 50 | + """Display information based on the provided info type. |
| 51 | +
|
| 52 | + Args: |
| 53 | + info_type: Type of information to display. Only SAMPLERS and MODEL_CONFIGS are supported. |
| 54 | +
|
| 55 | + Raises: |
| 56 | + ValueError: If an unsupported info_type is provided. |
| 57 | + """ |
| 58 | + if info_type == InfoType.SAMPLERS: |
| 59 | + self._display_sampler_info(sampler_type=kwargs.get("sampler_type")) |
| 60 | + elif info_type == InfoType.MODEL_CONFIGS: |
| 61 | + display_model_configs_table(self._model_configs) |
| 62 | + else: |
| 63 | + raise ValueError( |
| 64 | + f"Unsupported info_type: {info_type!r}. " |
| 65 | + f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}." |
| 66 | + ) |
| 67 | + |
| 68 | + def _display_sampler_info(self, sampler_type: SamplerType | None) -> None: |
| 69 | + if sampler_type is not None: |
| 70 | + title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" |
| 71 | + display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) |
| 72 | + else: |
| 73 | + display_sampler_table(self._sampler_params) |
| 74 | + |
| 75 | + |
| 76 | +class InterfaceInfo(InfoDisplay): |
| 77 | + def __init__(self, model_providers: list[ModelProvider]): |
| 78 | + self._model_providers = model_providers |
| 79 | + |
| 80 | + @overload |
| 81 | + def display(self, info_type: Literal[InfoType.MODEL_PROVIDERS], **kwargs) -> None: ... |
| 82 | + |
| 83 | + def display(self, info_type: DataDesignerInfoType, **kwargs) -> None: |
| 84 | + """Display information based on the provided info type. |
| 85 | +
|
| 86 | + Args: |
| 87 | + info_type: Type of information to display. Only MODEL_PROVIDERS is supported. |
29 | 88 |
|
30 | | - def display_model_configs(self) -> None: |
31 | | - display_model_configs_table(model_configs=self._model_configs) |
| 89 | + Raises: |
| 90 | + ValueError: If an unsupported info_type is provided. |
| 91 | + """ |
| 92 | + if info_type == InfoType.MODEL_PROVIDERS: |
| 93 | + display_model_providers_table(self._model_providers) |
| 94 | + else: |
| 95 | + raise ValueError( |
| 96 | + f"Unsupported info_type: {info_type!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}." |
| 97 | + ) |
0 commit comments