18 changes: 18 additions & 0 deletions studio/api.py
@@ -41,6 +41,7 @@
ExecutionRequest,
ExecutionResponse,
ExecutionStatus,
MultiModalContentPart,
NodeExecutionState,
WorkflowCreateRequest,
WorkflowExecution,
@@ -4270,11 +4271,28 @@ def _represent_multiline_str(dumper, data):
return dumper.represent_scalar('tag:yaml.org,2002:str', data)


def _represent_multimodal_content_part(dumper, data):
"""Represent MultiModalContentPart as a dict, excluding None values."""
d = {}
if data.type:
d['type'] = data.type
if data.text is not None:
d['text'] = data.text
if data.audio_url is not None:
d['audio_url'] = data.audio_url
if data.image_url is not None:
d['image_url'] = data.image_url
if data.video_url is not None:
d['video_url'] = data.video_url
return dumper.represent_dict(d)


def _get_yaml_dumper():
"""Get a custom YAML dumper that properly formats multiline strings."""
class CustomDumper(yaml.SafeDumper):
pass
CustomDumper.add_representer(str, _represent_multiline_str)
CustomDumper.add_representer(MultiModalContentPart, _represent_multimodal_content_part)
return CustomDumper
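
For context, a minimal, self-contained sketch of how this dumper setup could be exercised. The MultiModalContentPart stand-in below is an assumption (the real class is imported from studio's models and may be a Pydantic model with the same fields); the yaml.dump call with a custom Dumper class is standard PyYAML.

import yaml
from dataclasses import dataclass
from typing import Optional

# Stand-in for the real MultiModalContentPart, with the fields the representer reads.
@dataclass
class MultiModalContentPart:
    type: str = "image_url"
    text: Optional[str] = None
    audio_url: Optional[str] = None
    image_url: Optional[str] = None
    video_url: Optional[str] = None

class CustomDumper(yaml.SafeDumper):
    pass

def _represent_multimodal_content_part(dumper, data):
    # Same idea as the representer above: emit only the populated fields.
    d = {"type": data.type} if data.type else {}
    for key in ("text", "audio_url", "image_url", "video_url"):
        value = getattr(data, key)
        if value is not None:
            d[key] = value
    return dumper.represent_dict(d)

CustomDumper.add_representer(MultiModalContentPart, _represent_multimodal_content_part)

part = MultiModalContentPart(type="image_url", image_url="https://example.com/cat.png")
print(yaml.dump({"content": [part]}, Dumper=CustomDumper, default_flow_style=False))
# Example output (key order may vary):
# content:
# - image_url: https://example.com/cat.png
#   type: image_url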


74 changes: 74 additions & 0 deletions studio/frontend/src/lib/utils/stateVariables.ts
@@ -36,12 +36,54 @@ const FRAMEWORK_VARIABLES: StateVariable[] = [
{ name: 'id', source: 'framework', description: 'Record ID from data source' },
];

/**
* Extract fields added by transformations (e.g., AddNewFieldTransform, RenameColumnsTransform)
*/
function extractTransformationFields(transformations?: any[]): string[] {
if (!transformations || !Array.isArray(transformations)) return [];

const fields: string[] = [];

for (const transform of transformations) {
const params = transform.params || {};

// AddNewFieldTransform adds new columns from mapping keys
// RenameColumnsTransform renames columns - mapping values are new names
if (params.mapping && typeof params.mapping === 'object') {
const transformName = transform.transform || '';

if (transformName.includes('AddNewField') || transformName.includes('AddColumn')) {
// For AddNewFieldTransform, mapping keys are the new column names
fields.push(...Object.keys(params.mapping));
} else if (transformName.includes('Rename')) {
// For RenameColumnsTransform, mapping values are the new column names
fields.push(...Object.values(params.mapping).filter((v): v is string => typeof v === 'string'));
} else {
// Generic case - assume mapping keys are new fields
fields.push(...Object.keys(params.mapping));
}
}

// Some transforms use 'columns' or 'new_columns' param
if (params.new_columns && Array.isArray(params.new_columns)) {
fields.push(...params.new_columns);
}
if (params.column && typeof params.column === 'string') {
fields.push(params.column);
}
}

return fields;
}
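
To make the heuristics above concrete, here is a hypothetical transformations block as it might appear in a data-source config, shown as a plain Python dict for brevity. AddNewFieldTransform and RenameColumnsTransform are named in the code comments above; the exact params shape and the third transform are assumptions for illustration.

# Illustrative only — per the heuristics above, the helper would surface
# 'difficulty' (mapping key of an AddNewField* transform), 'question'
# (mapping value of a Rename* transform) and 'summary' (from new_columns).
transformations = [
    {"transform": "AddNewFieldTransform", "params": {"mapping": {"difficulty": "easy"}}},
    {"transform": "RenameColumnsTransform", "params": {"mapping": {"q": "question"}}},
    {"transform": "SomeOtherTransform", "params": {"new_columns": ["summary"]}},
]
# Expected extracted field names: ["difficulty", "question", "summary"]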

/**
* Extract column names from data source configuration.
* Uses fetchedColumns if available (from API), otherwise falls back to inline data.
*
* When aliases are defined in multiple data sources, columns are transformed to
* {alias}->{column} format to match the actual variable names used in prompts.
*
* Also extracts fields added by transformations (e.g., AddNewFieldTransform).
*
* @param dataConfig The data source configuration
* @param fetchedColumns Optional pre-fetched columns from the API (with transformations applied)
@@ -67,6 +109,22 @@ function extractDataColumns(dataConfig?: DataSourceConfig, fetchedColumns?: stri
});
}
}

// Also add transformation-added fields for single source
if (sources.length > 0 && sources[0].transformations) {
const transformFields = extractTransformationFields(sources[0].transformations);
for (const fieldName of transformFields) {
if (!variables.some(v => v.name === fieldName)) {
variables.push({
name: fieldName,
source: 'data',
sourceNode: 'DATA',
description: 'Field added by transformation'
});
}
}
}

return variables;
}

@@ -113,6 +171,22 @@ function extractDataColumns(dataConfig?: DataSourceConfig, fetchedColumns?: stri
});
}
}

// Add transformation-added fields for this source
if (source.transformations) {
const transformFields = extractTransformationFields(source.transformations);
for (const fieldName of transformFields) {
const varName = (hasMultipleSources && alias) ? `${alias}->${fieldName}` : fieldName;
if (!variables.some(v => v.name === varName)) {
variables.push({
name: varName,
source: 'data',
sourceNode: 'DATA',
description: `Field added by transformation${alias ? ` (${alias})` : ''}`
});
}
}
}
}

return variables;
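
A sketch of the multi-source case the loop above handles: with more than one aliased source, variables are exposed in the {alias}->{column} format. The config shape below (a list of sources with an alias key, shown as a Python dict) is an assumption for illustration.

# Hypothetical two-source config; names are illustrative.
sources = [
    {"alias": "docs", "transformations": [
        {"transform": "AddNewFieldTransform", "params": {"mapping": {"summary": ""}}},
    ]},
    {"alias": "qa"},
]
# With multiple sources, a fetched column 'title' from the first source is exposed as
# the variable 'docs->title', and the transformation-added field becomes 'docs->summary'.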
6 changes: 6 additions & 0 deletions studio/graph_builder.py
@@ -295,6 +295,11 @@ def _build_node(
# Extract output_keys for LLM/multi_llm nodes
output_keys = node_config.get("output_keys")

# Extract sampler_config for weighted_sampler nodes
sampler_config = None
if node_type == NodeType.WEIGHTED_SAMPLER and "attributes" in node_config:
sampler_config = {"attributes": node_config["attributes"]}

return WorkflowNode(
id=node_name,
node_type=node_type,
@@ -311,6 +316,7 @@
inner_graph=inner_graph,
node_config_map=node_config_map,
function_path=node_config.get("lambda") or node_config.get("function"),
sampler_config=sampler_config,
metadata={
"original_config": node_config,
},
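
As a rough illustration of the new extraction: a weighted_sampler node's attributes block is wrapped as-is into sampler_config. Only the "attributes" key is grounded in the code above; the nested values/weights shape is an assumed example.

# Hypothetical node_config for a weighted_sampler node (shape of "attributes" is assumed).
node_config = {
    "node_type": "weighted_sampler",
    "attributes": {
        "difficulty": {"values": ["easy", "medium", "hard"], "weights": [0.5, 0.3, 0.2]},
    },
}

# Mirrors the extraction in _build_node (the real check compares against NodeType.WEIGHTED_SAMPLER).
sampler_config = None
if node_config.get("node_type") == "weighted_sampler" and "attributes" in node_config:
    sampler_config = {"attributes": node_config["attributes"]}
# sampler_config is then passed through to WorkflowNode(...).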
34 changes: 34 additions & 0 deletions sygra/core/base_task_executor.py
@@ -282,6 +282,36 @@ def init_graph(self) -> StateGraph:
graph_builder = LangGraphBuilder(self.graph_config)
return cast(StateGraph, graph_builder.build())

def _process_feilds(self, data, select_columns: Optional[list] = None):
"""
Iterate through each record and keep only the fields listed in select_columns.
If select_columns is not defined, all fields are returned.
Also convert values of unsupported datatypes that would otherwise fail to serialize when the dataset is written.
Currently handled non-serializable datatype: numpy ndarray.
"""
select_columns = select_columns or []
final_data = []
filter_column = select_columns is not None and len(select_columns) > 0
# make sure the id column is preserved when a field filter is defined
if filter_column and "id" not in select_columns:
select_columns.append("id")

for record in data:
new_record = {}
for k, v in record.items():
# skip fields not in select_columns when a select list is defined
if filter_column and k not in select_columns:
continue
# convert the value to list if it is numpy array
if isinstance(v, np.ndarray):
v = v.tolist()
# store the updated key-value
new_record[k] = v
# now add into final dataset
if len(new_record) > 0:
final_data.append(new_record)
return final_data
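
A minimal sketch of what this filtering does (field names are illustrative); in practice select_columns comes from the task's data_config["source"]["fields"] list, as wired up in init_dataset below.

import numpy as np

records = [
    {"id": 1, "question": "2+2?", "embedding": np.array([0.1, 0.2]), "debug": "drop me"},
]
# Calling _process_feilds(records, ["question", "embedding"]) would keep "id" (always
# preserved once a field list is given), convert the ndarray to a plain Python list,
# and drop "debug":
# [{"id": 1, "question": "2+2?", "embedding": [0.1, 0.2]}]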

# Initialize and return the dataset for the task
def init_dataset(
self,
@@ -293,6 +323,10 @@ def init_dataset(

# Configure and load source data
data = self._load_source_data(data_config)
# get select fields if defined
select_fields = data_config.get("source", {}).get("fields", [])
# select only required fields
data = self._process_feilds(data, select_fields)

# Infer features for IterableDataset if they're missing/unknown
if isinstance(data, datasets.IterableDataset):