
Commit 4ab8564

SyGra studio minor issues (#118)

* workflow save issue and sampler node issue
* support fields filter and ndarray conversion
* import cleanup
* lint error
* lint error
* lint error

Co-authored-by: Vipul Mittal <[email protected]>
1 parent c149ab9 commit 4ab8564

File tree: 4 files changed, +132 −0 lines

  studio/api.py
  studio/frontend/src/lib/utils/stateVariables.ts
  studio/graph_builder.py
  sygra/core/base_task_executor.py

studio/api.py

Lines changed: 18 additions & 0 deletions

@@ -41,6 +41,7 @@
     ExecutionRequest,
     ExecutionResponse,
     ExecutionStatus,
+    MultiModalContentPart,
     NodeExecutionState,
     WorkflowCreateRequest,
     WorkflowExecution,

@@ -4270,11 +4271,28 @@ def _represent_multiline_str(dumper, data):
     return dumper.represent_scalar('tag:yaml.org,2002:str', data)


+def _represent_multimodal_content_part(dumper, data):
+    """Represent MultiModalContentPart as a dict, excluding None values."""
+    d = {}
+    if data.type:
+        d['type'] = data.type
+    if data.text is not None:
+        d['text'] = data.text
+    if data.audio_url is not None:
+        d['audio_url'] = data.audio_url
+    if data.image_url is not None:
+        d['image_url'] = data.image_url
+    if data.video_url is not None:
+        d['video_url'] = data.video_url
+    return dumper.represent_dict(d)
+
+
 def _get_yaml_dumper():
     """Get a custom YAML dumper that properly formats multiline strings."""
     class CustomDumper(yaml.SafeDumper):
         pass
     CustomDumper.add_representer(str, _represent_multiline_str)
+    CustomDumper.add_representer(MultiModalContentPart, _represent_multimodal_content_part)
     return CustomDumper
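Why the new representer is needed: yaml.SafeDumper refuses to serialize arbitrary Python objects (it raises RepresenterError), so a MultiModalContentPart reaching the workflow-save path could not be dumped before this change. The snippet below is a minimal, self-contained sketch — PartStub is a hypothetical stand-in for the real MultiModalContentPart model — showing how registering a representer on a SafeDumper subclass makes yaml.dump emit only the populated fields:

import yaml
from dataclasses import dataclass
from typing import Optional

@dataclass
class PartStub:
    # Hypothetical stand-in for MultiModalContentPart; field names mirror the representer above.
    type: str = "text"
    text: Optional[str] = None
    audio_url: Optional[str] = None
    image_url: Optional[str] = None
    video_url: Optional[str] = None

def _represent_part(dumper, data):
    # Emit only populated fields so the saved workflow YAML stays free of null noise.
    d = {k: getattr(data, k)
         for k in ("type", "text", "audio_url", "image_url", "video_url")
         if getattr(data, k) is not None}
    return dumper.represent_dict(d)

class StubDumper(yaml.SafeDumper):
    pass

StubDumper.add_representer(PartStub, _represent_part)

print(yaml.dump([PartStub(type="image_url", image_url="https://example.com/cat.png")],
                Dumper=StubDumper, sort_keys=False))
# - type: image_url
#   image_url: https://example.com/cat.png

Registering the type-specific representer on a SafeDumper subclass (rather than patching yaml.SafeDumper globally) keeps the behavior scoped to the studio's own dumper, which is the pattern _get_yaml_dumper follows above.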

studio/frontend/src/lib/utils/stateVariables.ts

Lines changed: 74 additions & 0 deletions

@@ -36,12 +36,54 @@ const FRAMEWORK_VARIABLES: StateVariable[] = [
   { name: 'id', source: 'framework', description: 'Record ID from data source' },
 ];

+/**
+ * Extract fields added by transformations (e.g., AddNewFieldTransform, RenameColumnsTransform)
+ */
+function extractTransformationFields(transformations?: any[]): string[] {
+  if (!transformations || !Array.isArray(transformations)) return [];
+
+  const fields: string[] = [];
+
+  for (const transform of transformations) {
+    const params = transform.params || {};
+
+    // AddNewFieldTransform adds new columns from mapping keys
+    // RenameColumnsTransform renames columns - mapping values are new names
+    if (params.mapping && typeof params.mapping === 'object') {
+      const transformName = transform.transform || '';
+
+      if (transformName.includes('AddNewField') || transformName.includes('AddColumn')) {
+        // For AddNewFieldTransform, mapping keys are the new column names
+        fields.push(...Object.keys(params.mapping));
+      } else if (transformName.includes('Rename')) {
+        // For RenameColumnsTransform, mapping values are the new column names
+        fields.push(...Object.values(params.mapping).filter((v): v is string => typeof v === 'string'));
+      } else {
+        // Generic case - assume mapping keys are new fields
+        fields.push(...Object.keys(params.mapping));
+      }
+    }
+
+    // Some transforms use 'columns' or 'new_columns' param
+    if (params.new_columns && Array.isArray(params.new_columns)) {
+      fields.push(...params.new_columns);
+    }
+    if (params.column && typeof params.column === 'string') {
+      fields.push(params.column);
+    }
+  }
+
+  return fields;
+}
+
 /**
  * Extract column names from data source configuration.
  * Uses fetchedColumns if available (from API), otherwise falls back to inline data.
  *
  * When aliases are defined in multiple data sources, columns are transformed to
  * {alias}->{column} format to match the actual variable names used in prompts.
+ *
+ * Also extracts fields added by transformations (e.g., AddNewFieldTransform).
  *
  * @param dataConfig The data source configuration
  * @param fetchedColumns Optional pre-fetched columns from the API (with transformations applied)

@@ -67,6 +109,22 @@ function extractDataColumns(dataConfig?: DataSourceConfig, fetchedColumns?: stri
       });
     }
   }
+
+  // Also add transformation-added fields for single source
+  if (sources.length > 0 && sources[0].transformations) {
+    const transformFields = extractTransformationFields(sources[0].transformations);
+    for (const fieldName of transformFields) {
+      if (!variables.some(v => v.name === fieldName)) {
+        variables.push({
+          name: fieldName,
+          source: 'data',
+          sourceNode: 'DATA',
+          description: 'Field added by transformation'
+        });
+      }
+    }
+  }
+
   return variables;
 }

@@ -113,6 +171,22 @@ function extractDataColumns(dataConfig?: DataSourceConfig, fetchedColumns?: stri
         });
       }
     }
+
+    // Add transformation-added fields for this source
+    if (source.transformations) {
+      const transformFields = extractTransformationFields(source.transformations);
+      for (const fieldName of transformFields) {
+        const varName = (hasMultipleSources && alias) ? `${alias}->${fieldName}` : fieldName;
+        if (!variables.some(v => v.name === varName)) {
+          variables.push({
+            name: varName,
+            source: 'data',
+            sourceNode: 'DATA',
+            description: `Field added by transformation${alias ? ` (${alias})` : ''}`
+          });
+        }
+      }
+    }
   }

   return variables;
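To make the key/value asymmetry above concrete: AddNewFieldTransform contributes its mapping keys as new columns, RenameColumnsTransform contributes its mapping values, and new_columns/column parameters are taken verbatim. A small Python sketch with an invented transformations list (mirroring the TypeScript helper, not taken from the repo) illustrates the resulting field names:

# Invented transformations list, shaped like the `transformations` array the helper receives.
transformations = [
    {"transform": "AddNewFieldTransform", "params": {"mapping": {"difficulty": "easy"}}},
    {"transform": "RenameColumnsTransform", "params": {"mapping": {"question": "prompt"}}},
    {"transform": "SomeCustomTransform", "params": {"new_columns": ["score"], "column": "notes"}},
]

fields: list[str] = []
for t in transformations:
    params = t.get("params", {})
    name = t.get("transform", "")
    mapping = params.get("mapping")
    if isinstance(mapping, dict):
        if "Rename" in name:
            # Rename transforms: the mapping *values* are the new column names.
            fields += [v for v in mapping.values() if isinstance(v, str)]
        else:
            # AddNewField / generic transforms: the mapping *keys* are the new column names.
            fields += list(mapping.keys())
    if isinstance(params.get("new_columns"), list):
        fields += params["new_columns"]
    if isinstance(params.get("column"), str):
        fields.append(params["column"])

print(fields)  # ['difficulty', 'prompt', 'score', 'notes']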

studio/graph_builder.py

Lines changed: 6 additions & 0 deletions

@@ -295,6 +295,11 @@ def _build_node(
        # Extract output_keys for LLM/multi_llm nodes
        output_keys = node_config.get("output_keys")

+        # Extract sampler_config for weighted_sampler nodes
+        sampler_config = None
+        if node_type == NodeType.WEIGHTED_SAMPLER and "attributes" in node_config:
+            sampler_config = {"attributes": node_config["attributes"]}
+
        return WorkflowNode(
            id=node_name,
            node_type=node_type,

@@ -311,6 +316,7 @@
            inner_graph=inner_graph,
            node_config_map=node_config_map,
            function_path=node_config.get("lambda") or node_config.get("function"),
+            sampler_config=sampler_config,
            metadata={
                "original_config": node_config,
            },
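The effect of the new branch is easiest to see on a concrete node definition. The sketch below uses an invented weighted_sampler config (the shape of the "attributes" block is an assumption for illustration, not taken from the SyGra docs) to show how the attributes are wrapped into sampler_config before being passed to WorkflowNode:

# Invented weighted_sampler node config; the "attributes" shape is an assumption for illustration.
node_config = {
    "node_type": "weighted_sampler",
    "attributes": {
        "tone": {"values": ["formal", "casual"], "weights": [0.7, 0.3]},
    },
}

sampler_config = None
if node_config.get("node_type") == "weighted_sampler" and "attributes" in node_config:
    # Same wrapping the builder now performs before constructing WorkflowNode(sampler_config=...).
    sampler_config = {"attributes": node_config["attributes"]}

print(sampler_config)
# {'attributes': {'tone': {'values': ['formal', 'casual'], 'weights': [0.7, 0.3]}}}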

sygra/core/base_task_executor.py

Lines changed: 34 additions & 0 deletions

@@ -282,6 +282,36 @@ def init_graph(self) -> StateGraph:
        graph_builder = LangGraphBuilder(self.graph_config)
        return cast(StateGraph, graph_builder.build())

+    def _process_feilds(self, data, select_columns: Optional[list] = None):
+        """
+        Iterate through each record and filter fields based on select_columns
+        If select_columns is not defined, then return all fields
+        Also perform data transformation for unsupported datatype, which fails to write because of serialization error
+        Currently supported non-serialized data type: ndarray
+        """
+        select_columns = select_columns or []
+        final_data = []
+        filter_column = select_columns is not None and len(select_columns) > 0
+        # make sure id column is preserved, if filter_column are defined
+        if filter_column and "id" not in select_columns:
+            select_columns.append("id")
+
+        for record in data:
+            new_record = {}
+            for k, v in record.items():
+                # skip the fields not in select_columns, when select list is defined
+                if filter_column and k not in select_columns:
+                    continue
+                # convert the value to list if it is numpy array
+                if isinstance(v, np.ndarray):
+                    v = v.tolist()
+                # store the updated key-value
+                new_record[k] = v
+            # now add into final dataset
+            if len(new_record) > 0:
+                final_data.append(new_record)
+        return final_data
+
    # Initialize and return the dataset for the task
    def init_dataset(
        self,

@@ -293,6 +323,10 @@

        # Configure and load source data
        data = self._load_source_data(data_config)
+        # get select fields if defined
+        select_fields = data_config.get("source", {}).get("fields", [])
+        # select only required fields
+        data = self._process_feilds(data, select_fields)

        # Infer features for IterableDataset if they're missing/unknown
        if isinstance(data, datasets.IterableDataset):
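The practical effect of _process_feilds on loaded records: unselected columns are dropped, the id column is always retained, and numpy arrays are converted to plain lists so the dataset can be written without serialization errors. A standalone sketch of the same filtering logic, run on made-up records (the `select_columns` list here stands in for data_config["source"]["fields"]):

import numpy as np

# Made-up records standing in for rows returned by _load_source_data().
records = [
    {"id": 1, "text": "hello", "embedding": np.array([0.1, 0.2]), "debug_info": "x"},
    {"id": 2, "text": "world", "embedding": np.array([0.3, 0.4]), "debug_info": "y"},
]
select_columns = ["text", "embedding"]  # "id" is re-added automatically by the method

keep = set(select_columns) | {"id"}
filtered = []
for record in records:
    new_record = {}
    for k, v in record.items():
        if k not in keep:
            continue                    # drop unselected columns such as debug_info
        if isinstance(v, np.ndarray):
            v = v.tolist()              # ndarray -> list, so the dataset serializes cleanly
        new_record[k] = v
    filtered.append(new_record)

print(filtered[0])  # {'id': 1, 'text': 'hello', 'embedding': [0.1, 0.2]}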
