67 changes: 65 additions & 2 deletions clarifai/cli/pipeline.py
@@ -84,6 +84,17 @@ def upload(path, no_lockfile):
default=False,
help='Monitor an existing pipeline run instead of starting a new one. Requires pipeline_version_run_id.',
)
@click.option(
'--set',
'set_params',
multiple=True,
help='Set input argument override (can be used multiple times). Format: key=value. Example: --set prompt="Hello" --set temperature="0.7"',
)
@click.option(
'--overrides-file',
type=click.Path(exists=True),
help='Path to JSON file containing input argument overrides. Inline --set parameters take precedence over file values.',
)
@click.pass_context
def run(
ctx,
@@ -100,15 +111,63 @@ def run(
monitor_interval,
log_file,
monitor,
set_params,
overrides_file,
):
"""Run a pipeline and monitor its progress."""
"""Run a pipeline and monitor its progress.

Examples:

# Run with inline parameter overrides
clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\
--set prompt="Summarize this" --set temperature="0.7"

# Run with file-based overrides
clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\
--overrides-file overrides.json

# Combine both (inline takes precedence)
clarifai pipeline run --compute_cluster_id=cc1 --nodepool_id=np1 \\
--overrides-file overrides.json --set prompt="Override prompt"
"""
import json

from clarifai.client.pipeline import Pipeline
from clarifai.utils.cli import from_yaml, validate_context
from clarifai.utils.pipeline_overrides import (
load_overrides_from_file,
merge_override_parameters,
parse_set_parameter,
)

validate_context(ctx)

# Parse input argument overrides
input_args_override = None
try:
# Parse inline --set parameters
inline_overrides = {}
if set_params:
for param in set_params:
key, value = parse_set_parameter(param)
inline_overrides[key] = value
logger.info(f"Inline override: {key}={value}")

# Load file-based overrides
file_overrides = {}
if overrides_file:
file_overrides = load_overrides_from_file(overrides_file)
logger.info(f"Loaded {len(file_overrides)} overrides from {overrides_file}")

# Merge overrides (inline takes precedence)
if inline_overrides or file_overrides:
input_args_override = merge_override_parameters(inline_overrides, file_overrides)
logger.info(f"Final overrides: {input_args_override}")

except (ValueError, FileNotFoundError) as e:
logger.error(f"Error processing input argument overrides: {e}")
raise click.Abort()

# Try to load from config-lock.yaml first if no config is specified
lockfile_path = os.path.join(os.getcwd(), "config-lock.yaml")
if not config and os.path.exists(lockfile_path):
@@ -205,7 +264,11 @@ def run(
result = pipeline.monitor_only(timeout=timeout, monitor_interval=monitor_interval)
else:
# Start new pipeline run and monitor it
result = pipeline.run(timeout=timeout, monitor_interval=monitor_interval)
result = pipeline.run(
timeout=timeout,
monitor_interval=monitor_interval,
input_args_override=input_args_override,
)
click.echo(json.dumps(result, indent=2, default=str))


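For reference, a minimal example of the JSON document that --overrides-file expects (the file name and values below are illustrative): a flat object whose keys are parameter names; non-string values are coerced to strings when loaded.

  {
    "prompt": "Summarize the attached document",
    "temperature": "0.7"
  }

Combined with --set prompt="Override prompt", the prompt from the command line would win while temperature would still come from the file, since inline parameters take precedence.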
58 changes: 56 additions & 2 deletions clarifai/client/pipeline.py
@@ -12,6 +12,7 @@
from clarifai.urls.helper import ClarifaiUrlHelper
from clarifai.utils.constants import DEFAULT_BASE
from clarifai.utils.logging import logger
from clarifai.utils.pipeline_overrides import build_argo_args_override


class Pipeline(Lister, BaseClient):
@@ -100,16 +101,39 @@ def __init__(
nodepool_id=self.nodepool_id,
)

def run(self, inputs: List = None, timeout: int = 3600, monitor_interval: int = 10) -> Dict:
def run(
self,
inputs: List = None,
timeout: int = 3600,
monitor_interval: int = 10,
input_args_override: Optional[Dict[str, str]] = None,
) -> Dict:
"""Run the pipeline and monitor its progress.

Args:
inputs (List): List of inputs to run the pipeline with. If None, runs without inputs.
timeout (int): Maximum time to wait for completion in seconds. Default 3600 (1 hour).
monitor_interval (int): Interval between status checks in seconds. Default 10.
input_args_override (Optional[Dict[str, str]]): Dictionary of parameter overrides for this run.
Keys are parameter names, values are parameter values as strings.
Example: {"prompt": "Summarize this", "temperature": "0.7"}

Returns:
Dict: The pipeline run result.
Dict: The pipeline run result including orchestration_spec if available.

Example:
>>> pipeline = Pipeline(
... pipeline_id='my-pipeline',
... pipeline_version_id='v1',
... user_id='user123',
... app_id='app456',
... nodepool_id='nodepool1',
... compute_cluster_id='cluster1',
... pat='your-pat'
... )
>>> result = pipeline.run(
... input_args_override={"prompt": "Summarize", "temperature": "0.7"}
... )
"""
# Create a new pipeline version run
pipeline_version_run = resources_pb2.PipelineVersionRun()
@@ -125,6 +149,36 @@ def run(self, inputs: List = None, timeout: int = 3600, monitor_interval: int =
)
pipeline_version_run.nodepools.extend([nodepool])

# Add input_args_override if provided
if input_args_override:
logger.info(f"Applying input argument overrides: {input_args_override}")
override_dict = build_argo_args_override(input_args_override)

# Once the updated protos are available, this becomes:
# pipeline_version_run.input_args_override.CopyFrom(override_proto)
# Until then, populate the field directly if this proto version already exposes it.
if hasattr(pipeline_version_run, 'input_args_override'):
# Proto field exists - use it directly
try:
from google.protobuf import json_format as jf

jf.ParseDict(override_dict, pipeline_version_run.input_args_override)
except Exception as e:
logger.warning(
f"Could not set input_args_override proto field: {e}. "
"This may require an updated clarifai-grpc version."
)
else:
# Proto field doesn't exist in the installed clarifai-grpc version;
# keep the overrides on the client for forward compatibility.
logger.warning(
"input_args_override field not yet available in the installed clarifai-grpc protos. "
"The overrides cannot be applied to this run until clarifai-grpc is updated."
)
# Keep the most recent overrides for potential future use
self._pending_overrides = override_dict

run_request = service_pb2.PostPipelineVersionRunsRequest()
run_request.user_app_id.CopyFrom(self.user_app_id)
run_request.pipeline_id = self.pipeline_id
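The forward-compatibility branch above hands the nested override dict to google.protobuf.json_format once the proto field exists. A minimal sketch of that mechanism, using struct_pb2.Struct purely as a stand-in for the not-yet-released input_args_override message (the real field and message names are assumptions of this PR, not of the installed protos):

  from google.protobuf import json_format, struct_pb2

  from clarifai.utils.pipeline_overrides import build_argo_args_override

  override_dict = build_argo_args_override({"prompt": "Summarize", "temperature": "0.7"})

  # Stand-in target; the real code parses into pipeline_version_run.input_args_override.
  placeholder = struct_pb2.Struct()
  json_format.ParseDict(override_dict, placeholder)

  # The structure round-trips unchanged.
  assert json_format.MessageToDict(placeholder) == override_dict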
134 changes: 134 additions & 0 deletions clarifai/utils/pipeline_overrides.py
@@ -0,0 +1,134 @@
"""Utilities for handling pipeline input argument overrides."""

import json
from typing import Any, Dict, Optional, Tuple


def parse_set_parameter(param_str: str) -> Tuple[str, str]:
"""Parse a --set parameter string into key-value pair.

Args:
param_str: Parameter string in format "key=value"

Returns:
Tuple of (key, value)

Raises:
ValueError: If parameter string is not in correct format
"""
if '=' not in param_str:
raise ValueError(
f"Invalid --set parameter format: '{param_str}'. Expected format: key=value"
)

key, value = param_str.split('=', 1)
key = key.strip()
value = value.strip()

if not key:
raise ValueError(f"Empty key in --set parameter: '{param_str}'")

return key, value
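# Illustrative usage (not part of this module): the split is on the first '=',
# so values may themselves contain '=':
#   parse_set_parameter('filter=status=done')  ->  ('filter', 'status=done')
#   parse_set_parameter('temperature=0.7')     ->  ('temperature', '0.7')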


def load_overrides_from_file(file_path: str) -> Dict[str, str]:
"""Load parameter overrides from a JSON file.

Args:
file_path: Path to JSON file containing overrides

Returns:
Dictionary of parameter name to value mappings

Raises:
FileNotFoundError: If file doesn't exist
ValueError: If file is not valid JSON or doesn't contain a dictionary
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in overrides file '{file_path}': {e}") from e

if not isinstance(data, dict):
raise ValueError(
f"Overrides file '{file_path}' must contain a JSON object (dictionary), got {type(data).__name__}"
)

# Convert all values to strings (Argo convention)
return {str(k): str(v) for k, v in data.items()}
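# Illustrative usage (not part of this module): non-string JSON values are coerced,
# so a file containing {"temperature": 0.7, "top_k": 5} loads as
# {'temperature': '0.7', 'top_k': '5'}.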


def merge_override_parameters(
inline_params: Optional[Dict[str, str]] = None, file_params: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""Merge inline and file-based parameter overrides.

Inline parameters take precedence over file parameters.

Args:
inline_params: Parameters from --set flags
file_params: Parameters from --overrides-file

Returns:
Merged dictionary of parameters
"""
result = {}

if file_params:
result.update(file_params)

if inline_params:
result.update(inline_params)

return result


def build_argo_args_override(parameters: Dict[str, str]) -> Dict[str, Any]:
"""Build an ArgoArgsOverride structure from parameter dictionary.

This creates a dictionary structure compatible with the proto message
format that will be used when the proto is available.

Args:
parameters: Dictionary of parameter name to value mappings

Returns:
Dictionary structure compatible with OrchestrationArgsOverride proto
"""
if not parameters:
return {}

# Build structure compatible with proto message format
# This will be serialized to proto when clarifai-grpc is updated
return {
'argo_args_override': {
'parameters': [{'name': name, 'value': value} for name, value in parameters.items()]
}
}
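# Illustrative usage (not part of this module):
#   build_argo_args_override({'prompt': 'Summarize', 'temperature': '0.7'})
# returns
#   {'argo_args_override': {'parameters': [
#       {'name': 'prompt', 'value': 'Summarize'},
#       {'name': 'temperature', 'value': '0.7'}]}}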


def validate_override_parameters(
override_params: Dict[str, str], allowed_params: Optional[set] = None
) -> Tuple[bool, Optional[str]]:
"""Validate that override parameters are allowed.

Args:
override_params: Parameters to validate
allowed_params: Set of allowed parameter names. If None, validation is skipped.

Returns:
Tuple of (is_valid, error_message). error_message is None if valid.
"""
if not override_params:
return True, None

if allowed_params is None:
# No validation rules provided, accept all parameters
return True, None

unknown_params = set(override_params.keys()) - allowed_params
if unknown_params:
return False, f"Unknown parameters: {', '.join(sorted(unknown_params))}"

return True, None
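Taken together, the helpers are intended to compose roughly as in the sketch below (paths and parameter names are illustrative; this assumes the module is importable from a local checkout):

  import json
  import tempfile

  from clarifai.utils.pipeline_overrides import (
      build_argo_args_override,
      load_overrides_from_file,
      merge_override_parameters,
      parse_set_parameter,
      validate_override_parameters,
  )

  # File-based overrides, as --overrides-file would supply them.
  with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
      json.dump({"prompt": "File prompt", "temperature": 0.7}, f)
  file_params = load_overrides_from_file(f.name)  # values coerced to strings

  # Inline overrides, as repeated --set flags would supply them.
  inline_params = dict(parse_set_parameter(p) for p in ["prompt=Summarize this"])

  merged = merge_override_parameters(inline_params, file_params)
  assert merged == {"prompt": "Summarize this", "temperature": "0.7"}  # inline wins

  is_valid, error = validate_override_parameters(merged, allowed_params={"prompt", "temperature"})
  assert is_valid and error is None

  print(build_argo_args_override(merged))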