2525
2626
class Modality(str, Enum):
    """Supported modality types for multimodal model data.

    Members are also ``str`` instances, so they compare equal to their raw
    string value (e.g. ``Modality.IMAGE == "image"``).
    """

    IMAGE = "image"
2931
3032
class ModalityDataType(str, Enum):
    """Data type formats for multimodal data.

    Members are also ``str`` instances, so they compare equal to their raw
    string value.
    """

    URL = "url"
    BASE64 = "base64"
3438
3539
3640class ImageFormat (str , Enum ):
41+ """Supported image formats for image modality."""
42+
3743 PNG = "png"
3844 JPG = "jpg"
3945 JPEG = "jpeg"
@@ -42,6 +48,8 @@ class ImageFormat(str, Enum):
4248
4349
class DistributionType(str, Enum):
    """Types of distributions for sampling inference parameters.

    Members are also ``str`` instances, so they compare equal to their raw
    string value (e.g. ``DistributionType.MANUAL == "manual"``).
    """

    UNIFORM = "uniform"
    MANUAL = "manual"
4755
@@ -56,10 +64,27 @@ def get_context(self, record: dict) -> dict[str, Any]: ...
5664
5765
5866class ImageContext (ModalityContext ):
67+ """Configuration for providing image context to multimodal models.
68+
69+ Attributes:
70+ modality: The modality type (always "image").
71+ column_name: Name of the column containing image data.
72+ data_type: Format of the image data ("url" or "base64").
73+ image_format: Image format (required for base64 data).
74+ """
75+
5976 modality : Modality = Modality .IMAGE
6077 image_format : Optional [ImageFormat ] = None
6178
6279 def get_context (self , record : dict ) -> dict [str , Any ]:
80+ """Get the context for the image modality.
81+
82+ Args:
83+ record: The record containing the image data.
84+
85+ Returns:
86+ The context for the image modality.
87+ """
6388 context = dict (type = "image_url" )
6489 context_value = record [self .column_name ]
6590 if self .data_type == ModalityDataType .URL :
@@ -90,6 +115,13 @@ def sample(self) -> float: ...
90115
91116
92117class ManualDistributionParams (ConfigBase ):
118+ """Parameters for manual distribution sampling.
119+
120+ Attributes:
121+ values: List of possible values to sample from.
122+ weights: Optional list of weights for each value. If not provided, all values have equal probability.
123+ """
124+
93125 values : List [float ] = Field (min_length = 1 )
94126 weights : Optional [List [float ]] = None
95127
@@ -107,14 +139,36 @@ def _validate_equal_lengths(self) -> Self:
107139
108140
class ManualDistribution(Distribution[ManualDistributionParams]):
    """Manual (discrete) distribution for sampling inference parameters.

    Samples from a discrete set of values with optional weights. Useful for testing
    specific values or creating custom probability distributions for temperature or top_p.

    Attributes:
        distribution_type: Type of distribution ("manual").
        params: Distribution parameters (values, weights).
    """

    distribution_type: Optional[DistributionType] = "manual"
    params: ManualDistributionParams

    def sample(self) -> float:
        """Sample a value from the manual distribution.

        Draws from NumPy's module-level random state (``np.random``), so
        results follow whatever seeding the process applied to it.

        Returns:
            A float value sampled from the manual distribution.
        """
        # With p=None, np.random.choice treats every value as equally likely.
        return float(np.random.choice(self.params.values, p=self.params.weights))
115162
116163
117164class UniformDistributionParams (ConfigBase ):
165+ """Parameters for uniform distribution sampling.
166+
167+ Attributes:
168+ low: Lower bound (inclusive).
169+ high: Upper bound (exclusive).
170+ """
171+
118172 low : float
119173 high : float
120174
@@ -126,17 +180,43 @@ def _validate_low_lt_high(self) -> Self:
126180
127181
class UniformDistribution(Distribution[UniformDistributionParams]):
    """Uniform distribution for sampling inference parameters.

    Samples values uniformly between low and high bounds. Useful for exploring
    a continuous range of values for temperature or top_p.

    Attributes:
        distribution_type: Type of distribution ("uniform").
        params: Distribution parameters (low, high).
    """

    distribution_type: Optional[DistributionType] = "uniform"
    params: UniformDistributionParams

    def sample(self) -> float:
        """Sample a value from the uniform distribution.

        Draws from NumPy's module-level random state (``np.random``), so
        results follow whatever seeding the process applied to it.

        Returns:
            A float value sampled from the uniform distribution.
        """
        # size=1 yields a one-element array; take its only entry and convert
        # the NumPy scalar to a plain Python float.
        return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
134203
135204
# Union of the concrete distribution classes, for annotating fields that accept any of them.
DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
137206
138207
139208class InferenceParameters (ConfigBase ):
209+ """Configuration for LLM inference parameters.
210+
211+ Attributes:
212+ temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
213+ top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
214+ max_tokens: Maximum number of tokens (includes both input and output tokens).
215+ max_parallel_requests: Maximum number of parallel requests to the model API.
216+ timeout: Timeout in seconds for each request.
217+ extra_body: Additional parameters to pass to the model API.
218+ """
219+
140220 temperature : Optional [Union [float , DistributionT ]] = None
141221 top_p : Optional [Union [float , DistributionT ]] = None
142222 max_tokens : Optional [int ] = Field (default = None , ge = 1 )
@@ -146,6 +226,11 @@ class InferenceParameters(ConfigBase):
146226
147227 @property
148228 def generate_kwargs (self ) -> dict [str , Union [float , int ]]:
229+ """Get the generate kwargs for the inference parameters.
230+
231+ Returns:
232+ A dictionary of the generate kwargs.
233+ """
149234 result = {}
150235 if self .temperature is not None :
151236 result ["temperature" ] = (
@@ -206,13 +291,32 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) -
206291
207292
class ModelConfig(ConfigBase):
    """Configuration for a model used for generation.

    Attributes:
        alias: User-defined alias to reference in column configurations.
        model: Model identifier (e.g., from build.nvidia.com or other providers).
        inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
            Defaults to an ``InferenceParameters()`` built with no arguments.
        provider: Optional model provider name if using custom providers.
    """

    alias: str
    model: str
    # default_factory builds the default lazily instead of sharing one
    # pre-built instance (presumably pydantic `Field` semantics; confirm
    # against ConfigBase).
    inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
    provider: Optional[str] = None
213307
214308
215309class ModelProvider (ConfigBase ):
310+ """Configuration for a custom model provider.
311+
312+ Attributes:
313+ name: Name of the model provider.
314+ endpoint: API endpoint URL for the provider.
315+ provider_type: Provider type (default: "openai"). Determines the API format to use.
316+ api_key: Optional API key for authentication.
317+ extra_body: Additional parameters to pass in API requests.
318+ """
319+
216320 name : str
217321 endpoint : str
218322 provider_type : str = "openai"
0 commit comments