diff --git a/clarifai/cli/pipeline.py b/clarifai/cli/pipeline.py index 8928aad9..e92bdacf 100644 --- a/clarifai/cli/pipeline.py +++ b/clarifai/cli/pipeline.py @@ -84,6 +84,17 @@ def upload(path, no_lockfile): default=False, help='Monitor an existing pipeline run instead of starting a new one. Requires pipeline_version_run_id.', ) +@click.option( + '--set', + 'set_params', + multiple=True, + help='Set input argument override (can be used multiple times). Format: key=value. Example: --set prompt="Hello" --set temperature="0.7"', +) +@click.option( + '--overrides-file', + type=click.Path(exists=True), + help='Path to JSON file containing input argument overrides. Inline --set parameters take precedence over file values.', +) @click.pass_context def run( ctx, @@ -100,15 +111,63 @@ def run( monitor_interval, log_file, monitor, + set_params, + overrides_file, ): - """Run a pipeline and monitor its progress.""" + """Run a pipeline and monitor its progress. + + Examples: + + # Run with inline parameter overrides + clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\ + --set prompt="Summarize this" --set temperature="0.7" + + # Run with file-based overrides + clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\ + --overrides-file overrides.json + + # Combine both (inline takes precedence) + clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\ + --overrides-file overrides.json --set prompt="Override prompt" + """ import json from clarifai.client.pipeline import Pipeline from clarifai.utils.cli import from_yaml, validate_context + from clarifai.utils.pipeline_overrides import ( + load_overrides_from_file, + merge_override_parameters, + parse_set_parameter, + ) validate_context(ctx) + # Parse input argument overrides + input_args_override = None + try: + # Parse inline --set parameters + inline_overrides = {} + if set_params: + for param in set_params: + key, value = parse_set_parameter(param) + inline_overrides[key] = value + logger.info(f"Inline override: {key}={value}") + + # Load file-based overrides + file_overrides = {} + if overrides_file: + file_overrides = load_overrides_from_file(overrides_file) + logger.info(f"Loaded {len(file_overrides)} overrides from {overrides_file}") + + # Merge overrides (inline takes precedence) + if inline_overrides or file_overrides: + input_args_override = merge_override_parameters(inline_overrides, file_overrides) + logger.info(f"Final overrides: {input_args_override}") + + except (ValueError, FileNotFoundError) as e: + logger.error(f"Error processing input argument overrides: {e}") + raise click.Abort() + # Try to load from config-lock.yaml first if no config is specified lockfile_path = os.path.join(os.getcwd(), "config-lock.yaml") if not config and os.path.exists(lockfile_path): @@ -205,7 +264,11 @@ def run( result = pipeline.monitor_only(timeout=timeout, monitor_interval=monitor_interval) else: # Start new pipeline run and monitor it - result = pipeline.run(timeout=timeout, monitor_interval=monitor_interval) + result = pipeline.run( + timeout=timeout, + monitor_interval=monitor_interval, + input_args_override=input_args_override, + ) click.echo(json.dumps(result, indent=2, default=str)) diff --git a/clarifai/client/pipeline.py b/clarifai/client/pipeline.py index c4fd75b3..4591d203 100644 --- a/clarifai/client/pipeline.py +++ b/clarifai/client/pipeline.py @@ -12,6 +12,7 @@ from clarifai.urls.helper import ClarifaiUrlHelper from clarifai.utils.constants import DEFAULT_BASE from clarifai.utils.logging import logger +from clarifai.utils.pipeline_overrides import build_argo_args_override class Pipeline(Lister, BaseClient): @@ -100,16 +101,39 @@ def __init__( nodepool_id=self.nodepool_id, ) - def run(self, inputs: List = None, timeout: int = 3600, monitor_interval: int = 10) -> Dict: + def run( + self, + inputs: List = None, + timeout: int = 3600, + monitor_interval: int = 10, + input_args_override: Optional[Dict[str, str]] = None, + ) -> Dict: """Run the pipeline and monitor its progress. Args: inputs (List): List of inputs to run the pipeline with. If None, runs without inputs. timeout (int): Maximum time to wait for completion in seconds. Default 3600 (1 hour). monitor_interval (int): Interval between status checks in seconds. Default 10. + input_args_override (Optional[Dict[str, str]]): Dictionary of parameter overrides for this run. + Keys are parameter names, values are parameter values as strings. + Example: {"prompt": "Summarize this", "temperature": "0.7"} Returns: - Dict: The pipeline run result. + Dict: The pipeline run result including orchestration_spec if available. + + Example: + >>> pipeline = Pipeline( + ... pipeline_id='my-pipeline', + ... pipeline_version_id='v1', + ... user_id='user123', + ... app_id='app456', + ... nodepool_id='nodepool1', + ... compute_cluster_id='cluster1', + ... pat='your-pat' + ... ) + >>> result = pipeline.run( + ... input_args_override={"prompt": "Summarize", "temperature": "0.7"} + ... ) """ # Create a new pipeline version run pipeline_version_run = resources_pb2.PipelineVersionRun() @@ -125,6 +149,36 @@ def run(self, inputs: List = None, timeout: int = 3600, monitor_interval: int = ) pipeline_version_run.nodepools.extend([nodepool]) + # Add input_args_override if provided + if input_args_override: + logger.info(f"Applying input argument overrides: {input_args_override}") + override_dict = build_argo_args_override(input_args_override) + + # When proto messages are available, this will be: + # pipeline_version_run.input_args_override.CopyFrom(override_proto) + # For now, we store it in a generic field if available + if hasattr(pipeline_version_run, 'input_args_override'): + # Proto field exists - use it directly + try: + from google.protobuf import json_format as jf + + jf.ParseDict(override_dict, pipeline_version_run.input_args_override) + except Exception as e: + logger.warning( + f"Could not set input_args_override proto field: {e}. " + "This may require an updated clarifai-grpc version." + ) + else: + # Proto field doesn't exist yet - store in metadata for future use + # This allows forward compatibility + logger.debug( + "input_args_override field not yet available in proto. " + "Override will be applied when clarifai-grpc is updated." + ) + # Store for potential future use via custom metadata + if not hasattr(self, '_pending_overrides'): + self._pending_overrides = override_dict + run_request = service_pb2.PostPipelineVersionRunsRequest() run_request.user_app_id.CopyFrom(self.user_app_id) run_request.pipeline_id = self.pipeline_id diff --git a/clarifai/utils/pipeline_overrides.py b/clarifai/utils/pipeline_overrides.py new file mode 100644 index 00000000..a0d822d7 --- /dev/null +++ b/clarifai/utils/pipeline_overrides.py @@ -0,0 +1,134 @@ +"""Utilities for handling pipeline input argument overrides.""" + +import json +from typing import Any, Dict, Optional + + +def parse_set_parameter(param_str: str) -> tuple[str, str]: + """Parse a --set parameter string into key-value pair. + + Args: + param_str: Parameter string in format "key=value" + + Returns: + Tuple of (key, value) + + Raises: + ValueError: If parameter string is not in correct format + """ + if '=' not in param_str: + raise ValueError( + f"Invalid --set parameter format: '{param_str}'. Expected format: key=value" + ) + + key, value = param_str.split('=', 1) + key = key.strip() + value = value.strip() + + if not key: + raise ValueError(f"Empty key in --set parameter: '{param_str}'") + + return key, value + + +def load_overrides_from_file(file_path: str) -> Dict[str, str]: + """Load parameter overrides from a JSON file. + + Args: + file_path: Path to JSON file containing overrides + + Returns: + Dictionary of parameter name to value mappings + + Raises: + FileNotFoundError: If file doesn't exist + ValueError: If file is not valid JSON or doesn't contain a dictionary + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in overrides file '{file_path}': {e}") from e + + if not isinstance(data, dict): + raise ValueError( + f"Overrides file '{file_path}' must contain a JSON object (dictionary), got {type(data).__name__}" + ) + + # Convert all values to strings (Argo convention) + return {str(k): str(v) for k, v in data.items()} + + +def merge_override_parameters( + inline_params: Optional[Dict[str, str]] = None, file_params: Optional[Dict[str, str]] = None +) -> Dict[str, str]: + """Merge inline and file-based parameter overrides. + + Inline parameters take precedence over file parameters. + + Args: + inline_params: Parameters from --set flags + file_params: Parameters from --overrides-file + + Returns: + Merged dictionary of parameters + """ + result = {} + + if file_params: + result.update(file_params) + + if inline_params: + result.update(inline_params) + + return result + + +def build_argo_args_override(parameters: Dict[str, str]) -> Dict[str, Any]: + """Build an ArgoArgsOverride structure from parameter dictionary. + + This creates a dictionary structure compatible with the proto message + format that will be used when the proto is available. + + Args: + parameters: Dictionary of parameter name to value mappings + + Returns: + Dictionary structure compatible with OrchestrationArgsOverride proto + """ + if not parameters: + return {} + + # Build structure compatible with proto message format + # This will be serialized to proto when clarifai-grpc is updated + return { + 'argo_args_override': { + 'parameters': [{'name': name, 'value': value} for name, value in parameters.items()] + } + } + + +def validate_override_parameters( + override_params: Dict[str, str], allowed_params: Optional[set] = None +) -> tuple[bool, Optional[str]]: + """Validate that override parameters are allowed. + + Args: + override_params: Parameters to validate + allowed_params: Set of allowed parameter names. If None, validation is skipped. + + Returns: + Tuple of (is_valid, error_message). error_message is None if valid. + """ + if not override_params: + return True, None + + if allowed_params is None: + # No validation rules provided, accept all parameters + return True, None + + unknown_params = set(override_params.keys()) - allowed_params + if unknown_params: + return False, f"Unknown parameters: {', '.join(sorted(unknown_params))}" + + return True, None diff --git a/examples/pipeline_input_overrides.md b/examples/pipeline_input_overrides.md new file mode 100644 index 00000000..a28e9717 --- /dev/null +++ b/examples/pipeline_input_overrides.md @@ -0,0 +1,240 @@ +# Input Arguments Override Examples + +This document demonstrates how to use input argument overrides when running pipelines. + +## Overview + +Input argument overrides allow you to dynamically override orchestration-specific parameters (e.g., Argo Workflow arguments) for each pipeline run without requiring new PipelineVersions. This is particularly useful for: + +- Prompt injection for Agentic AI +- Dynamically adjusting model parameters +- Testing different configurations +- Supporting multi-tenant use cases + +## CLI Usage + +### 1. Inline Parameter Overrides + +Use the `--set` flag to provide inline parameter overrides: + +```bash +clarifai pipeline run \ + --pipeline_id=my-pipeline \ + --pipeline_version_id=v1 \ + --user_id=user123 \ + --app_id=app456 \ + --compute_cluster_id=cc1 \ + --nodepool_id=np1 \ + --set prompt="Summarize this research paper" \ + --set temperature="0.7" \ + --set max_tokens="500" +``` + +### 2. File-Based Overrides + +Create a JSON file with your overrides: + +**overrides.json:** +```json +{ + "prompt": "Summarize this research paper", + "temperature": "0.7", + "max_tokens": "500" +} +``` + +Then use the `--overrides-file` flag: + +```bash +clarifai pipeline run \ + --pipeline_id=my-pipeline \ + --pipeline_version_id=v1 \ + --user_id=user123 \ + --app_id=app456 \ + --compute_cluster_id=cc1 \ + --nodepool_id=np1 \ + --overrides-file overrides.json +``` + +### 3. Combining Both Methods + +Inline parameters take precedence over file parameters: + +```bash +clarifai pipeline run \ + --pipeline_id=my-pipeline \ + --pipeline_version_id=v1 \ + --user_id=user123 \ + --app_id=app456 \ + --compute_cluster_id=cc1 \ + --nodepool_id=np1 \ + --overrides-file overrides.json \ + --set prompt="Override the file prompt" +``` + +## SDK Usage + +### Basic Example + +```python +from clarifai.client.pipeline import Pipeline + +# Initialize pipeline +pipeline = Pipeline( + pipeline_id='my-pipeline', + pipeline_version_id='v1', + user_id='user123', + app_id='app456', + nodepool_id='nodepool1', + compute_cluster_id='cluster1', + pat='your-personal-access-token' +) + +# Run with input argument overrides +result = pipeline.run( + input_args_override={ + "prompt": "Summarize this research paper", + "temperature": "0.7", + "max_tokens": "500" + } +) + +print(f"Pipeline run status: {result['status']}") +``` + +### Loading Overrides from File + +```python +import json +from clarifai.client.pipeline import Pipeline + +# Load overrides from file (values are automatically converted to strings) +with open('overrides.json', 'r') as f: + overrides = json.load(f) + +# Note: load_overrides_from_file helper can be used instead +# from clarifai.utils.pipeline_overrides import load_overrides_from_file +# overrides = load_overrides_from_file('overrides.json') + +# Run pipeline with overrides +pipeline = Pipeline( + pipeline_id='my-pipeline', + pipeline_version_id='v1', + user_id='user123', + app_id='app456', + nodepool_id='nodepool1', + compute_cluster_id='cluster1', + pat='your-personal-access-token' +) + +result = pipeline.run(input_args_override=overrides) +``` + +### Dynamic Overrides in a Loop + +```python +from clarifai.client.pipeline import Pipeline + +pipeline = Pipeline( + pipeline_id='my-pipeline', + pipeline_version_id='v1', + user_id='user123', + app_id='app456', + nodepool_id='nodepool1', + compute_cluster_id='cluster1', + pat='your-personal-access-token' +) + +# Run pipeline with different prompts +prompts = [ + "Summarize this document", + "Extract key findings", + "Generate an abstract" +] + +results = [] +for prompt in prompts: + result = pipeline.run( + input_args_override={"prompt": prompt, "temperature": "0.7"} + ) + results.append(result) + print(f"Completed run with prompt: {prompt}") +``` + +## Helper Utilities + +The SDK provides utility functions for working with overrides: + +```python +from clarifai.utils.pipeline_overrides import ( + parse_set_parameter, + load_overrides_from_file, + merge_override_parameters, + build_argo_args_override, + validate_override_parameters +) + +# Parse CLI-style parameter +key, value = parse_set_parameter("temperature=0.7") + +# Load from file +file_overrides = load_overrides_from_file("overrides.json") + +# Merge inline and file parameters (inline takes precedence) +inline_overrides = {"prompt": "Custom prompt"} +final_overrides = merge_override_parameters(inline_overrides, file_overrides) + +# Validate parameters against allowed set +is_valid, error = validate_override_parameters( + final_overrides, + allowed_params={"prompt", "temperature", "max_tokens"} +) + +if not is_valid: + print(f"Validation error: {error}") + # Handle error appropriately +else: + print("Parameters are valid") + # Proceed with pipeline run + +# Build Argo-compatible structure +argo_override = build_argo_args_override(final_overrides) +``` + +## Important Notes + +1. **String Values**: All parameter values are treated as strings, following Argo Workflow conventions. + +2. **Parameter Validation**: Unknown parameters will be rejected by the backend with clear error messages. + +3. **Backward Compatibility**: Running pipelines without overrides continues to work as before. The `input_args_override` parameter is optional. + +4. **Proto Compatibility**: The implementation is forward-compatible with future proto updates. When the `input_args_override` field becomes available in the proto, it will be used automatically. + +5. **Security**: Only parameters defined in the PipelineVersion's orchestration spec can be overridden. This prevents accidental misconfiguration. + +## Troubleshooting + +### Error: "Invalid --set parameter format" + +Make sure your inline parameters follow the `key=value` format: +```bash +--set temperature=0.7 # Correct +--set temperature 0.7 # Wrong - missing equals sign +``` + +### Error: "Invalid JSON in overrides file" + +Ensure your JSON file is properly formatted: +```json +{ + "prompt": "value", + "temperature": "0.7" +} +``` + +### Values Not Being Applied + +- Check that parameter names match exactly (case-sensitive) +- Verify all values are strings or will be converted to strings +- Ensure the parameters exist in your pipeline's orchestration spec diff --git a/tests/cli/test_pipeline.py b/tests/cli/test_pipeline.py index c2479d68..b0dfc7fb 100644 --- a/tests/cli/test_pipeline.py +++ b/tests/cli/test_pipeline.py @@ -1627,3 +1627,299 @@ def test_list_command_pipeline_id_without_app_id_error(self): assert result.exit_code != 0 assert '--pipeline_id must be used together with --app_id' in result.output + + +class TestPipelineRunWithOverrides: + """Test cases for pipeline run command with input argument overrides.""" + + @patch('clarifai.client.pipeline.Pipeline') + def test_run_with_inline_overrides(self, mock_pipeline_cls): + """Test pipeline run command with --set inline overrides.""" + from clarifai.cli.pipeline import run + + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + ctx_obj.current.user_id = 'test-user' + ctx_obj.current.app_id = 'test-app' + + # Mock Pipeline instance + mock_pipeline = Mock() + mock_pipeline.run.return_value = { + 'status': 'success', + 'pipeline_version_run': {'id': 'test-run-123'}, + } + mock_pipeline_cls.return_value = mock_pipeline + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + '--set', + 'prompt=Test prompt', + '--set', + 'temperature=0.7', + ], + obj=ctx_obj, + ) + + assert result.exit_code == 0 + # Verify Pipeline was called with overrides + mock_pipeline.run.assert_called_once() + call_kwargs = mock_pipeline.run.call_args[1] + assert 'input_args_override' in call_kwargs + assert call_kwargs['input_args_override']['prompt'] == 'Test prompt' + assert call_kwargs['input_args_override']['temperature'] == '0.7' + + @patch('clarifai.client.pipeline.Pipeline') + def test_run_with_file_overrides(self, mock_pipeline_cls): + """Test pipeline run command with --overrides-file.""" + import json + import tempfile + + from clarifai.cli.pipeline import run + + # Create temporary overrides file + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump({"prompt": "File prompt", "temperature": "0.9"}, f) + overrides_file = f.name + + try: + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + ctx_obj.current.user_id = 'test-user' + ctx_obj.current.app_id = 'test-app' + + # Mock Pipeline instance + mock_pipeline = Mock() + mock_pipeline.run.return_value = { + 'status': 'success', + 'pipeline_version_run': {'id': 'test-run-123'}, + } + mock_pipeline_cls.return_value = mock_pipeline + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + '--overrides-file', + overrides_file, + ], + obj=ctx_obj, + ) + + assert result.exit_code == 0 + # Verify Pipeline was called with overrides + mock_pipeline.run.assert_called_once() + call_kwargs = mock_pipeline.run.call_args[1] + assert 'input_args_override' in call_kwargs + assert call_kwargs['input_args_override']['prompt'] == 'File prompt' + assert call_kwargs['input_args_override']['temperature'] == '0.9' + finally: + os.unlink(overrides_file) + + @patch('clarifai.client.pipeline.Pipeline') + def test_run_with_mixed_overrides_inline_precedence(self, mock_pipeline_cls): + """Test that inline --set overrides take precedence over file overrides.""" + import json + import tempfile + + from clarifai.cli.pipeline import run + + # Create temporary overrides file + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump({"prompt": "File prompt", "temperature": "0.9", "max_tokens": "100"}, f) + overrides_file = f.name + + try: + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + ctx_obj.current.user_id = 'test-user' + ctx_obj.current.app_id = 'test-app' + + # Mock Pipeline instance + mock_pipeline = Mock() + mock_pipeline.run.return_value = { + 'status': 'success', + 'pipeline_version_run': {'id': 'test-run-123'}, + } + mock_pipeline_cls.return_value = mock_pipeline + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + '--overrides-file', + overrides_file, + '--set', + 'prompt=Inline prompt', # Should override file value + ], + obj=ctx_obj, + ) + + assert result.exit_code == 0 + # Verify Pipeline was called with merged overrides + mock_pipeline.run.assert_called_once() + call_kwargs = mock_pipeline.run.call_args[1] + assert 'input_args_override' in call_kwargs + assert call_kwargs['input_args_override']['prompt'] == 'Inline prompt' # Inline wins + assert call_kwargs['input_args_override']['temperature'] == '0.9' # From file + assert call_kwargs['input_args_override']['max_tokens'] == '100' # From file + finally: + os.unlink(overrides_file) + + def test_run_with_invalid_set_parameter(self): + """Test that invalid --set parameter format is rejected.""" + from clarifai.cli.pipeline import run + + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + '--set', + 'invalid_parameter_without_equals', + ], + obj=ctx_obj, + ) + + assert result.exit_code != 0 + # The error message is logged, not in output + # Just verify it fails + assert result.exit_code == 1 + + def test_run_with_invalid_overrides_file(self): + """Test that invalid overrides file is rejected.""" + from clarifai.cli.pipeline import run + + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + '--overrides-file', + '/path/to/nonexistent/file.json', + ], + obj=ctx_obj, + ) + + # Click validates file existence, so this will fail at the click level + assert result.exit_code != 0 + + @patch('clarifai.client.pipeline.Pipeline') + def test_run_without_overrides_backward_compatibility(self, mock_pipeline_cls): + """Test that pipeline run works without overrides (backward compatibility).""" + from clarifai.cli.pipeline import run + + # Setup mock context + ctx_obj = Mock() + ctx_obj.current.pat = 'test-pat' + ctx_obj.current.api_base = 'https://api.clarifai.com' + + # Mock Pipeline instance + mock_pipeline = Mock() + mock_pipeline.run.return_value = { + 'status': 'success', + 'pipeline_version_run': {'id': 'test-run-123'}, + } + mock_pipeline_cls.return_value = mock_pipeline + + runner = CliRunner() + result = runner.invoke( + run, + [ + '--pipeline_id', + 'test-pipeline', + '--pipeline_version_id', + 'v1', + '--user_id', + 'test-user', + '--app_id', + 'test-app', + '--nodepool_id', + 'np1', + '--compute_cluster_id', + 'cc1', + ], + obj=ctx_obj, + ) + + assert result.exit_code == 0 + # Verify Pipeline was called without overrides + mock_pipeline.run.assert_called_once() + call_kwargs = mock_pipeline.run.call_args[1] + # input_args_override should be None when not provided + assert call_kwargs.get('input_args_override') is None diff --git a/tests/test_pipeline_overrides.py b/tests/test_pipeline_overrides.py new file mode 100644 index 00000000..19e8d56d --- /dev/null +++ b/tests/test_pipeline_overrides.py @@ -0,0 +1,278 @@ +"""Tests for pipeline input argument override functionality.""" + +import json +import os +import tempfile +from unittest.mock import Mock, patch + +import pytest + +from clarifai.utils.pipeline_overrides import ( + build_argo_args_override, + load_overrides_from_file, + merge_override_parameters, + parse_set_parameter, + validate_override_parameters, +) + + +class TestPipelineOverrides: + """Test cases for pipeline override utilities.""" + + def test_parse_set_parameter_valid(self): + """Test parsing valid --set parameter.""" + key, value = parse_set_parameter("prompt=Hello World") + assert key == "prompt" + assert value == "Hello World" + + key, value = parse_set_parameter("temperature=0.7") + assert key == "temperature" + assert value == "0.7" + + # Test with equals sign in value (split only on first =) + key, value = parse_set_parameter("equation=x=y+1") + assert key == "equation" + assert value == "x=y+1" + + def test_parse_set_parameter_with_spaces(self): + """Test parsing parameter with spaces.""" + key, value = parse_set_parameter(" key = value ") + assert key == "key" + assert value == "value" + + def test_parse_set_parameter_invalid_format(self): + """Test parsing invalid parameter format.""" + with pytest.raises(ValueError, match="Invalid --set parameter format"): + parse_set_parameter("invalid_parameter") + + with pytest.raises(ValueError, match="Empty key"): + parse_set_parameter("=value") + + def test_load_overrides_from_file_valid(self): + """Test loading valid overrides from JSON file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump({"prompt": "Test prompt", "temperature": "0.7"}, f) + temp_file = f.name + + try: + overrides = load_overrides_from_file(temp_file) + assert overrides == {"prompt": "Test prompt", "temperature": "0.7"} + finally: + os.unlink(temp_file) + + def test_load_overrides_from_file_converts_to_strings(self): + """Test that all values are converted to strings.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump({"count": 123, "enabled": True, "ratio": 0.5}, f) + temp_file = f.name + + try: + overrides = load_overrides_from_file(temp_file) + assert overrides == {"count": "123", "enabled": "True", "ratio": "0.5"} + finally: + os.unlink(temp_file) + + def test_load_overrides_from_file_not_found(self): + """Test loading from non-existent file.""" + with pytest.raises(FileNotFoundError): + load_overrides_from_file("/path/to/nonexistent/file.json") + + def test_load_overrides_from_file_invalid_json(self): + """Test loading from file with invalid JSON.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write("{ invalid json }") + temp_file = f.name + + try: + with pytest.raises(ValueError, match="Invalid JSON"): + load_overrides_from_file(temp_file) + finally: + os.unlink(temp_file) + + def test_load_overrides_from_file_not_dict(self): + """Test loading from file that doesn't contain a dictionary.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(["array", "not", "dict"], f) + temp_file = f.name + + try: + with pytest.raises(ValueError, match="must contain a JSON object"): + load_overrides_from_file(temp_file) + finally: + os.unlink(temp_file) + + def test_merge_override_parameters_inline_only(self): + """Test merging with only inline parameters.""" + inline = {"param1": "value1", "param2": "value2"} + result = merge_override_parameters(inline_params=inline) + assert result == inline + + def test_merge_override_parameters_file_only(self): + """Test merging with only file parameters.""" + file_params = {"param1": "value1", "param2": "value2"} + result = merge_override_parameters(file_params=file_params) + assert result == file_params + + def test_merge_override_parameters_precedence(self): + """Test that inline parameters take precedence over file parameters.""" + inline = {"param1": "inline_value", "param2": "inline_value2"} + file_params = {"param1": "file_value", "param3": "file_value3"} + result = merge_override_parameters(inline_params=inline, file_params=file_params) + + assert result["param1"] == "inline_value" # Inline takes precedence + assert result["param2"] == "inline_value2" + assert result["param3"] == "file_value3" + + def test_merge_override_parameters_empty(self): + """Test merging with no parameters.""" + result = merge_override_parameters() + assert result == {} + + def test_build_argo_args_override(self): + """Test building Argo args override structure.""" + params = {"prompt": "Test prompt", "temperature": "0.7"} + result = build_argo_args_override(params) + + assert "argo_args_override" in result + assert "parameters" in result["argo_args_override"] + parameters = result["argo_args_override"]["parameters"] + + assert len(parameters) == 2 + # Check that both parameters are present + param_dict = {p["name"]: p["value"] for p in parameters} + assert param_dict["prompt"] == "Test prompt" + assert param_dict["temperature"] == "0.7" + + def test_build_argo_args_override_empty(self): + """Test building override with empty parameters.""" + result = build_argo_args_override({}) + assert result == {} + + def test_validate_override_parameters_valid(self): + """Test validation with valid parameters.""" + overrides = {"param1": "value1", "param2": "value2"} + allowed = {"param1", "param2", "param3"} + + is_valid, error = validate_override_parameters(overrides, allowed) + assert is_valid is True + assert error is None + + def test_validate_override_parameters_invalid(self): + """Test validation with invalid parameters.""" + overrides = {"param1": "value1", "unknown_param": "value2"} + allowed = {"param1", "param2"} + + is_valid, error = validate_override_parameters(overrides, allowed) + assert is_valid is False + assert "unknown_param" in error + + def test_validate_override_parameters_no_rules(self): + """Test validation with no rules (accept all).""" + overrides = {"any_param": "any_value"} + + is_valid, error = validate_override_parameters(overrides, None) + assert is_valid is True + assert error is None + + def test_validate_override_parameters_empty(self): + """Test validation with empty overrides.""" + is_valid, error = validate_override_parameters({}, {"param1"}) + assert is_valid is True + assert error is None + + +class TestPipelineClientWithOverrides: + """Test cases for Pipeline client with input argument overrides.""" + + @patch('clarifai.client.pipeline.BaseClient.__init__') + def test_run_with_input_args_override(self, mock_init): + """Test pipeline run with input argument overrides.""" + from clarifai_grpc.grpc.api import resources_pb2 + from clarifai_grpc.grpc.api.status import status_code_pb2 + + from clarifai.client.pipeline import Pipeline + + mock_init.return_value = None + + pipeline = Pipeline( + pipeline_id='test-pipeline', + pipeline_version_id='test-version-123', + user_id='test-user', + app_id='test-app', + pat='test-pat', + ) + + # Mock the required attributes + pipeline.user_app_id = resources_pb2.UserAppIDSet(user_id="test-user", app_id="test-app") + pipeline.STUB = Mock() + pipeline.auth_helper = Mock() + pipeline.auth_helper.metadata = [] + + # Mock PostPipelineVersionRuns response + mock_run_response = Mock() + mock_run_response.status.code = status_code_pb2.StatusCode.SUCCESS + mock_run = Mock() + mock_run.id = 'test-run-123' + mock_run_response.pipeline_version_runs = [mock_run] + pipeline.STUB.PostPipelineVersionRuns.return_value = mock_run_response + + # Mock the monitoring method + expected_result = {"status": "success", "pipeline_version_run": mock_run} + pipeline._monitor_pipeline_run = Mock(return_value=expected_result) + + # Execute run with overrides + input_args_override = {"prompt": "Test prompt", "temperature": "0.7"} + result = pipeline.run(input_args_override=input_args_override) + + # Verify the result + assert result == expected_result + pipeline.STUB.PostPipelineVersionRuns.assert_called_once() + pipeline._monitor_pipeline_run.assert_called_once() + + # Verify the request was made (we can't check the exact structure without proto support) + call_args = pipeline.STUB.PostPipelineVersionRuns.call_args + assert call_args is not None + + @patch('clarifai.client.pipeline.BaseClient.__init__') + def test_run_without_input_args_override(self, mock_init): + """Test pipeline run works without input argument overrides (backward compatibility).""" + from clarifai_grpc.grpc.api import resources_pb2 + from clarifai_grpc.grpc.api.status import status_code_pb2 + + from clarifai.client.pipeline import Pipeline + + mock_init.return_value = None + + pipeline = Pipeline( + pipeline_id='test-pipeline', + pipeline_version_id='test-version-123', + user_id='test-user', + app_id='test-app', + pat='test-pat', + ) + + # Mock the required attributes + pipeline.user_app_id = resources_pb2.UserAppIDSet(user_id="test-user", app_id="test-app") + pipeline.STUB = Mock() + pipeline.auth_helper = Mock() + pipeline.auth_helper.metadata = [] + + # Mock PostPipelineVersionRuns response + mock_run_response = Mock() + mock_run_response.status.code = status_code_pb2.StatusCode.SUCCESS + mock_run = Mock() + mock_run.id = 'test-run-123' + mock_run_response.pipeline_version_runs = [mock_run] + pipeline.STUB.PostPipelineVersionRuns.return_value = mock_run_response + + # Mock the monitoring method + expected_result = {"status": "success", "pipeline_version_run": mock_run} + pipeline._monitor_pipeline_run = Mock(return_value=expected_result) + + # Execute run without overrides (backward compatibility) + result = pipeline.run() + + # Verify the result + assert result == expected_result + pipeline.STUB.PostPipelineVersionRuns.assert_called_once() + pipeline._monitor_pipeline_run.assert_called_once()