-
Notifications
You must be signed in to change notification settings - Fork 51
chore: (FTUE ) updated display pipeline for builder and interface like objects #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
446a00e
e8e22be
4afa753
07db807
74bdf6f
bddeb4f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Generic, Protocol, TypeVar | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from .models import ModelConfig, ModelProvider | ||
| from .utils.constants import DEFAULT_NUM_RECORDS | ||
| from .utils.info import InterfaceInfo | ||
|
|
||
| if TYPE_CHECKING: | ||
| from .analysis.dataset_profiler import DatasetProfilerResults | ||
| from .config_builder import DataDesignerConfigBuilder | ||
| from .preview_results import PreviewResults | ||
|
|
||
|
|
||
| class ResultsProtocol(Protocol): | ||
| def load_analysis(self) -> DatasetProfilerResults: ... | ||
| def load_dataset(self) -> pd.DataFrame: ... | ||
|
|
||
|
|
||
| ResultsT = TypeVar("ResultsT", bound=ResultsProtocol) | ||
|
|
||
|
|
||
| class DataDesignerInterface(ABC, Generic[ResultsT]): | ||
| @abstractmethod | ||
| def create( | ||
| self, | ||
| config_builder: DataDesignerConfigBuilder, | ||
| *, | ||
| num_records: int = DEFAULT_NUM_RECORDS, | ||
| ) -> ResultsT: ... | ||
|
|
||
| @abstractmethod | ||
| def preview( | ||
| self, | ||
| config_builder: DataDesignerConfigBuilder, | ||
| *, | ||
| num_records: int = DEFAULT_NUM_RECORDS, | ||
| ) -> PreviewResults: ... | ||
|
|
||
| @abstractmethod | ||
| def get_default_model_configs(self) -> list[ModelConfig]: ... | ||
|
|
||
| @abstractmethod | ||
| def get_default_model_providers(self) -> list[ModelProvider]: ... | ||
|
|
||
| @property | ||
| @abstractmethod | ||
| def info(self) -> InterfaceInfo: ... | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
|
|
||
| from abc import ABC, abstractmethod | ||
| from enum import Enum | ||
| import logging | ||
| import os | ||
| from pathlib import Path | ||
| from typing import Any, Generic, List, Optional, TypeVar, Union | ||
|
|
||
|
|
@@ -12,9 +14,18 @@ | |
|
|
||
| from .base import ConfigBase | ||
| from .errors import InvalidConfigError | ||
| from .utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P | ||
| from .utils.constants import ( | ||
| MAX_TEMPERATURE, | ||
| MAX_TOP_P, | ||
| MIN_TEMPERATURE, | ||
| MIN_TOP_P, | ||
| NVIDIA_API_KEY_ENV_VAR_NAME, | ||
| NVIDIA_PROVIDER_NAME, | ||
| ) | ||
| from .utils.io_helpers import smart_load_yaml | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class Modality(str, Enum): | ||
| IMAGE = "image" | ||
|
|
@@ -204,9 +215,14 @@ class ModelConfig(ConfigBase): | |
| provider: Optional[str] = None | ||
|
|
||
|
|
||
| def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]: | ||
| if model_configs is None: | ||
| return [] | ||
| class ModelProvider(ConfigBase): | ||
| name: str | ||
| endpoint: str | ||
| provider_type: str = "openai" | ||
| api_key: str | None = None | ||
|
|
||
|
|
||
| def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]: | ||
| if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs): | ||
| return model_configs | ||
| json_config = smart_load_yaml(model_configs) | ||
|
|
@@ -215,3 +231,51 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) | |
| "The list of model configs must be provided under model_configs in the configuration file." | ||
| ) | ||
| return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]] | ||
|
|
||
|
|
||
| def get_default_nvidia_model_configs() -> list[ModelConfig]: | ||
| if not get_nvidia_api_key(): | ||
| logger.warning( | ||
| "‼️🔑 'NVIDIA_API_KEY' environment variable is not set. Please set it to your API key from 'build.nvidia.com' if you want to use the default NVIDIA model configs." | ||
| ) | ||
| return [ | ||
| ModelConfig( | ||
| alias="text", | ||
| model="nvidia/nvidia-nemotron-nano-9b-v2", | ||
| provider=NVIDIA_PROVIDER_NAME, | ||
| inference_parameters=InferenceParameters( | ||
| temperature=0.85, | ||
| top_p=0.95, | ||
| ), | ||
| ), | ||
| ModelConfig( | ||
| alias="reasoning", | ||
| model="nvidia/llama-3.3-nemotron-super-49b-v1.5", | ||
|
||
| provider=NVIDIA_PROVIDER_NAME, | ||
| inference_parameters=InferenceParameters( | ||
| temperature=0.35, | ||
| top_p=0.95, | ||
| ), | ||
| ), | ||
| ModelConfig( | ||
| alias="vision", | ||
| model="nvidia/nemotron-nano-12b-v2-vl", | ||
| provider=NVIDIA_PROVIDER_NAME, | ||
| inference_parameters=InferenceParameters( | ||
| temperature=0.85, | ||
| top_p=0.95, | ||
| ), | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
| def get_nvidia_api_key() -> Optional[str]: | ||
| return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME) | ||
|
|
||
|
|
||
| def get_default_nvidia_model_provider() -> ModelProvider: | ||
| return ModelProvider( | ||
| name=NVIDIA_PROVIDER_NAME, | ||
| endpoint="https://integrate.api.nvidia.com/v1", | ||
| api_key=NVIDIA_API_KEY_ENV_VAR_NAME, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,23 +1,97 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from enum import Enum | ||
| from typing import Literal, TypeVar, overload | ||
|
|
||
| from ..models import ModelConfig, ModelProvider | ||
| from ..sampler_params import SamplerType | ||
| from .type_helpers import get_sampler_params | ||
| from .visualization import display_sampler_table | ||
| from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table | ||
|
|
||
|
|
||
| class InfoType(str, Enum): | ||
| SAMPLERS = "SAMPLERS" | ||
| MODEL_CONFIGS = "MODEL_CONFIGS" | ||
| MODEL_PROVIDERS = "MODEL_PROVIDERS" | ||
|
|
||
|
|
||
| ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS] | ||
| DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS] | ||
| InfoTypeT = TypeVar("InfoTypeT", bound=InfoType) | ||
|
|
||
|
|
||
| class InfoDisplay(ABC): | ||
| """Base class for info display classes that provide type-safe display methods.""" | ||
|
|
||
| @abstractmethod | ||
| def display(self, info_type: InfoTypeT, **kwargs) -> None: | ||
| """Display information based on the provided info type. | ||
| Args: | ||
| info_type: Type of information to display. | ||
| """ | ||
| ... | ||
|
|
||
|
|
||
| class DataDesignerInfo: | ||
| def __init__(self): | ||
| class ConfigBuilderInfo(InfoDisplay): | ||
| def __init__(self, model_configs: list[ModelConfig]): | ||
| self._sampler_params = get_sampler_params() | ||
| self._model_configs = model_configs | ||
|
|
||
| @overload | ||
| def display(self, info_type: Literal[InfoType.SAMPLERS], **kwargs) -> None: ... | ||
|
|
||
| @overload | ||
| def display(self, info_type: Literal[InfoType.MODEL_CONFIGS], **kwargs) -> None: ... | ||
nabinchha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None: | ||
| """Display information based on the provided info type. | ||
| Args: | ||
| info_type: Type of information to display. Only SAMPLERS and MODEL_CONFIGS are supported. | ||
| Raises: | ||
| ValueError: If an unsupported info_type is provided. | ||
| """ | ||
| if info_type == InfoType.SAMPLERS: | ||
| self._display_sampler_info(sampler_type=kwargs.get("sampler_type")) | ||
| elif info_type == InfoType.MODEL_CONFIGS: | ||
| display_model_configs_table(self._model_configs) | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported info_type: {str(info_type)!r}. " | ||
| f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}." | ||
| ) | ||
|
|
||
| def _display_sampler_info(self, sampler_type: SamplerType | None) -> None: | ||
| if sampler_type is not None: | ||
| title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" | ||
| display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) | ||
| else: | ||
| display_sampler_table(self._sampler_params) | ||
|
|
||
|
|
||
| class InterfaceInfo(InfoDisplay): | ||
| def __init__(self, model_providers: list[ModelProvider]): | ||
| self._model_providers = model_providers | ||
|
|
||
| @overload | ||
| def display(self, info_type: Literal[InfoType.MODEL_PROVIDERS], **kwargs) -> None: ... | ||
nabinchha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @property | ||
| def sampler_table(self) -> None: | ||
| display_sampler_table(self._sampler_params) | ||
| def display(self, info_type: DataDesignerInfoType, **kwargs) -> None: | ||
| """Display information based on the provided info type. | ||
| @property | ||
| def sampler_types(self) -> list[str]: | ||
| return [s.value for s in SamplerType] | ||
| Args: | ||
| info_type: Type of information to display. Only MODEL_PROVIDERS is supported. | ||
| def display_sampler(self, sampler_type: SamplerType) -> None: | ||
| title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" | ||
| display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) | ||
| Raises: | ||
| ValueError: If an unsupported info_type is provided. | ||
| """ | ||
| if info_type == InfoType.MODEL_PROVIDERS: | ||
| display_model_providers_table(self._model_providers) | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported info_type: {str(info_type)!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}." | ||
| ) | ||



There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mikeknep just want to make sure you are are of these changes for the nmp dd client.